@@ -116,6 +116,35 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
116116 s " expected codegen dispatcher activity, got $after" )
117117 }
118118
119+ /**
120+ * Stronger form of [[assertCodegenDidWork ]] for composition tests. Asserts that the full
121+ * expression subtree compiled into at most one kernel. The "one JNI crossing per nesting level"
122+ * alternative (the PR description's foil) would produce one `(bytes, specs)` cache entry per
123+ * nested sub-expression, so `compileCount` would be N and the cache would grow by N after the
124+ * first batch. Asserting `compileCount <= 1` and `cacheSize` growth `<= 1` directly falsifies
125+ * that shape.
126+ *
127+ * Uses `<=` rather than `==` because the compile cache is JVM-wide and shared across tests; a
128+ * prior test that already compiled the same `(expression bytes, input schema)` pair will make
129+ * this run a cache hit (`compileCount == 0`). The dispatcher-activity check guards against a
130+ * silent fallback where the query runs through Spark and the first two assertions pass
131+ * vacuously.
132+ */
133+ private def assertOneKernelForSubtree (f : => Unit ): Unit = {
134+ CometCodegenDispatchUDF .resetStats()
135+ val sizeBefore = CometCodegenDispatchUDF .stats().cacheSize
136+ f
137+ val after = CometCodegenDispatchUDF .stats()
138+ assert(
139+ after.compileCount <= 1 ,
140+ s " expected <= 1 compile for the composed subtree, got $after" )
141+ val grew = after.cacheSize - sizeBefore
142+ assert(grew <= 1 , s " expected cache to grow by <= 1 entry, grew by $grew; stats= $after" )
143+ assert(
144+ after.compileCount + after.cacheHitCount >= 1 ,
145+ s " expected codegen dispatcher activity, got $after" )
146+ }
147+
119148 /**
120149 * Assert that the dispatcher's compile cache contains a kernel compiled for the given input
121150 * Arrow vector classes (in ordinal order) and output Spark `DataType`. This is a specialization
@@ -473,14 +502,18 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
473502 }
474503
475504 test(" codegen: three-deep ScalaUDF composition lvl3(lvl2(lvl1(s)))" ) {
476- // Three user UDFs stacked in one tree: String -> String -> String -> Int. Single Janino
477- // compile, three `ctx.addReferenceObj` calls in the fused method. Verifies the dispatcher
478- // doesn't flatten or reorder the chain.
505+ // Three user UDFs stacked in one tree: String -> String -> String -> Int. The fused kernel
506+ // carries three `ctx.addReferenceObj` calls. `assertOneKernelForSubtree` asserts that the
507+ // whole chain collapses into a single compile rather than one per nesting level.
508+ // Input rows intentionally exclude nulls: per-batch nullability is a cache-key dimension
509+ // (`nullable()` reads `getNullCount != 0`), so a null-present batch compiles a second kernel
510+ // specialized for `nullable=true`. Null handling through composed UDFs is covered by the
511+ // other composition tests above.
479512 spark.udf.register(" lvl1" , (s : String ) => if (s == null ) null else s.toUpperCase)
480513 spark.udf.register(" lvl2" , (s : String ) => if (s == null ) null else s.reverse)
481514 spark.udf.register(" lvl3" , (s : String ) => if (s == null ) - 1 else s.length)
482- withSubjects(" abc" , null , " hello world" , " x" ) {
483- assertCodegenDidWork {
515+ withSubjects(" abc" , " hello world" , " x" ) {
516+ assertOneKernelForSubtree {
484517 checkSparkAnswerAndOperator(sql(" SELECT lvl3(lvl2(lvl1(s))) FROM t" ))
485518 }
486519 assertKernelSignaturePresent(Seq (classOf [VarCharVector ]), IntegerType )
@@ -490,14 +523,16 @@ class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPla
490523 test(" codegen: multi-column ScalaUDF composition join(upperU(c1), lowerU(c2))" ) {
491524 // One multi-arg user UDF consuming two other user UDFs, each on a different input column.
492525 // The bound tree has two BoundReferences, and the kernel is specialized on two VarCharVector
493- // columns. Proves multi-column composition of pure user UDFs works with zero Spark helpers.
526+ // columns. `assertOneKernelForSubtree` asserts that the two-branch composition fuses into a
527+ // single kernel rather than one per branch or one per UDF.
528+ // Input rows intentionally exclude nulls (see note on the three-deep test above).
494529 spark.udf.register(" upperU" , (s : String ) => if (s == null ) null else s.toUpperCase)
495530 spark.udf.register(" lowerU" , (s : String ) => if (s == null ) null else s.toLowerCase)
496531 spark.udf.register(
497532 " joinU" ,
498533 (a : String , b : String ) => if (a == null || b == null ) null else s " $a- $b" )
499- withTwoStringCols((" Abc" , " XYZ" ), (" Foo" , null ), (null , " Bar" ), (" Hi" , " Lo" )) {
500- assertCodegenDidWork {
534+ withTwoStringCols((" Abc" , " XYZ" ), (" Foo" , " bar " ), (" baz " , " Bar" ), (" Hi" , " Lo" )) {
535+ assertOneKernelForSubtree {
501536 checkSparkAnswerAndOperator(sql(" SELECT joinU(upperU(c1), lowerU(c2)) FROM t" ))
502537 }
503538 assertKernelSignaturePresent(
0 commit comments