@@ -21,7 +21,6 @@ package org.apache.comet
2121
2222import org .scalatest .funsuite .AnyFunSuite
2323
24- import org .apache .arrow .vector .{VarCharVector , ViewVarCharVector }
2524import org .apache .spark .sql .catalyst .InternalRow
2625import org .apache .spark .sql .catalyst .expressions .{BoundReference , Concat , Expression , LeafExpression , Length , Literal , Nondeterministic , RegExpReplace , RLike , Unevaluable , Upper }
2726import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , CodegenFallback , ExprCode }
@@ -30,6 +29,11 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
3029import org .apache .comet .udf .CometBatchKernelCodegen
3130import org .apache .comet .udf .CometBatchKernelCodegen .ArrowColumnSpec
3231
32+ // Resolve Arrow vector classes through the codegen object so tests see the same `Class` objects
33+ // the shaded `common` module sees. A direct `classOf[org.apache.arrow.vector.VarCharVector]` here
34+ // would be the unshaded class from the test classpath, which is not `==` to the shaded class the
35+ // production code pattern-matches against.
36+
3337/**
3438 * Generated-source inspection tests. These exercise `CometBatchKernelCodegen.generateSource` and
3539 * assert on the emitted Java directly, without invoking Janino. The goal is to catch regressions
@@ -48,8 +52,13 @@ import org.apache.comet.udf.CometBatchKernelCodegen.ArrowColumnSpec
4852 */
4953class CometCodegenSourceSuite extends AnyFunSuite {
5054
51- private val nullableString = ArrowColumnSpec (classOf [VarCharVector ], nullable = true )
52- private val nonNullableString = ArrowColumnSpec (classOf [VarCharVector ], nullable = false )
55+ private val varCharVectorClass =
56+ CometBatchKernelCodegen .vectorClassBySimpleName(" VarCharVector" )
57+ private val viewVarCharVectorClass =
58+ CometBatchKernelCodegen .vectorClassBySimpleName(" ViewVarCharVector" )
59+
60+ private val nullableString = ArrowColumnSpec (varCharVectorClass, nullable = true )
61+ private val nonNullableString = ArrowColumnSpec (varCharVectorClass, nullable = false )
5362
5463 private def gen (
5564 expr : org.apache.spark.sql.catalyst.expressions.Expression ,
@@ -94,7 +103,7 @@ class CometCodegenSourceSuite extends AnyFunSuite {
94103 }
95104
96105 test(" ViewVarCharVector getUTF8String branches inline vs referenced without allocating" ) {
97- val viewSpec = ArrowColumnSpec (classOf [ ViewVarCharVector ] , nullable = true )
106+ val viewSpec = ArrowColumnSpec (viewVarCharVectorClass , nullable = true )
98107 val expr = Length (BoundReference (0 , StringType , nullable = true ))
99108 val src = gen(expr, viewSpec)
100109 // The view case reads the 16-byte view entry and picks inline vs referenced data without a
@@ -177,8 +186,8 @@ class CometCodegenSourceSuite extends AnyFunSuite {
177186 // being null would skip evaluation, but Concat's null handling differs). Expect the
178187 // default path without the `if (colX.isNull(i) || colY.isNull(i))` wrapper, letting Spark's
179188 // own `ev.code` handle nulls correctly.
180- val nullable1 = ArrowColumnSpec (classOf [ VarCharVector ] , nullable = true )
181- val nullable2 = ArrowColumnSpec (classOf [ VarCharVector ] , nullable = true )
189+ val nullable1 = ArrowColumnSpec (varCharVectorClass , nullable = true )
190+ val nullable2 = ArrowColumnSpec (varCharVectorClass , nullable = true )
182191 val expr = RLike (
183192 Concat (
184193 Seq (
0 commit comments