@@ -36,13 +36,11 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGener
3636import org .apache .comet .udf .codegen .CometScalaUDFCodegen
3737
3838/**
39- * Randomized tests for the Arrow-direct codegen dispatcher. Schema-driven coverage of every input
40- * vector class via random parquet files, plus a decimal precision-scale sweep across the
41- * `Decimal.MAX_LONG_DIGITS=18` boundary at varying null densities.
42- *
43- * Extends [[CometTestBase ]] (not [[CometFuzzTestBase ]]) and inlines the random parquet setup so
44- * tests run once. The base's three-way cross-product (`shuffle` x `nativeC2R`) does not change
45- * the codegen path for projection-only queries, so it would be runtime cost without coverage.
39+ * Randomized tests for the Arrow-direct codegen dispatcher: schema-driven coverage of every input
40+ * vector class, plus a decimal precision-scale sweep across the `Decimal.MAX_LONG_DIGITS=18`
41+ * boundary at varying null densities. Extends [[CometTestBase ]] (not [[CometFuzzTestBase ]])
42+ * because the base's `shuffle` x `nativeC2R` cross-product `test()` override is irrelevant for
43+ * projection-only queries.
4644 */
4745class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper {
4846
@@ -102,6 +100,9 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
102100 1000 ,
103101 dataGenOptions)
104102 }
103+
104+ spark.read.parquet(mixedTypesFilename).createOrReplaceTempView(" t1" )
105+ spark.read.parquet(nestedTypesFilename).createOrReplaceTempView(" t2" )
105106 }
106107
107108 protected override def afterAll (): Unit = {
@@ -112,7 +113,8 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
112113
113114 private val RowCount : Int = 512
114115 private val nullDensities : Seq [Double ] = Seq (0.0 , 0.1 , 0.5 , 1.0 )
115- // (precision, scale) pairs spanning both sides of the MAX_LONG_DIGITS=18 boundary.
116+ // (precision, scale) shapes spanning both sides of `Decimal.MAX_LONG_DIGITS=18`: small short,
117+ // boundary short with varying scale, just-past-boundary long, and max decimal128.
116118 private val decimalShapes : Seq [(Int , Int )] = Seq ((9 , 2 ), (18 , 0 ), (18 , 9 ), (19 , 0 ), (38 , 10 ))
117119
118120 override protected def sparkConf : SparkConf =
@@ -165,16 +167,12 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
165167 * Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map
166168 * column, regardless of element type.
167169 *
168- * Avoiding `Seq[T]` / `Map[K, V]` materialization is deliberate: Spark's
169- * `org.apache.spark.sql.catalyst.expressions.objects.MapObjects` codegen reads each element via
170- * `getLong`/`getFloat`/etc. unconditionally and only checks `isNullAt` afterward to decide
171- * whether to wrap the value in `Option` or null. On null positions of a dictionary-encoded
172- * primitive Arrow vector the underlying ID buffer holds uninitialized bytes, and
173- * `decodeToLong/decodeToFloat` against those garbage IDs throws
174- * `ArrayIndexOutOfBoundsException`. The buggy code is in Spark; the failure reproduces in pure
175- * Spark execution (no Comet on the trace), so `checkSparkAnswerAndOperator` cannot compute the
176- * baseline answer. `cardinality(col)` exercises the kernel's `getArray`/`getMap` length read
177- * while bypassing the element deserializer entirely.
170+ * Avoids `Seq[T]` / `Map[K, V]` UDF arg materialization: Spark's `MapObjects.doGenCode` reads
171+ * each element unconditionally and null-checks afterward, so on null positions of a
172+ * dictionary-encoded primitive Arrow vector the garbage ID buffer feeds
173+ * `dictionary.decodeToLong/decodeToFloat` and throws `ArrayIndexOutOfBoundsException`. Bug
174+ * reproduces in pure Spark; `cardinality(col)` exercises `getArray`/`getMap` without entering
175+ * the element deserializer.
178176 */
179177 private lazy val cardinalityProbeUdf : String = {
180178 val name = " sz_complex"
@@ -183,9 +181,8 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
183181 }
184182
185183 test(" identity ScalaUDF over every primitive column" ) {
186- val df = spark.read.parquet(mixedTypesFilename)
187- df.createOrReplaceTempView(" t1" )
188- val primitiveFields = df.schema.fields.filterNot(f => isComplexType(f.dataType))
184+ val primitiveFields =
185+ spark.table(" t1" ).schema.fields.filterNot(f => isComplexType(f.dataType))
189186 assert(primitiveFields.nonEmpty, " expected at least one primitive column in random schema" )
190187 for (field <- primitiveFields) {
191188 val udfName = s " id_ ${field.name}"
@@ -203,36 +200,26 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
203200 }
204201
205202 test(" complex-probe ScalaUDF on every complex column" ) {
206- val df = spark.read.parquet(mixedTypesFilename)
207- df.createOrReplaceTempView(" t1" )
208- val complexFields = df.schema.fields.filter(f => isComplexType(f.dataType))
203+ val complexFields = spark.table(" t1" ).schema.fields.filter(f => isComplexType(f.dataType))
209204 assert(complexFields.nonEmpty, " expected at least one complex column in random schema" )
210205 for (field <- complexFields) {
211206 probeComplexColumn(field, viewName = " t1" )
212207 }
213208 }
214209
215210 test(" complex-probe ScalaUDF on top-level columns of deeply nested schema" ) {
216- val df = spark.read.parquet(nestedTypesFilename)
217- df.createOrReplaceTempView(" t2" )
218- for (field <- df.schema.fields) {
211+ for (field <- spark.table(" t2" ).schema.fields) {
219212 probeComplexColumn(field, viewName = " t2" )
220213 }
221214 }
222215
223216 /**
224- * Element-level fuzz for nested array reads. For every `Array<primitive>` column in the random
225- * schema, runs `id_X(array_max(col))` so Spark's `ArrayMax.doGenCode` walks every element of
226- * every row and calls the kernel's nested element getter
227- * (`getInt`/`getLong`/`getDecimal`/etc.). The cardinality probe deliberately avoids element
228- * materialization, so without this test no fuzz coverage exists on the element-getter paths the
229- * unsafe-access optimization would touch. `array_max` is comparison-only on every primitive
230- * Spark supports, so one expression covers all 14 element types.
217+ * Element-level fuzz for nested array reads: `ArrayMax.doGenCode` walks every element of every
218+ * row, calling the kernel's nested element getter — the path the unsafe-getter optimization
219+ * touches and which the cardinality probe deliberately skips.
231220 */
232221 test(" array_max element fuzz: every Array<primitive> column" ) {
233- val df = spark.read.parquet(mixedTypesFilename)
234- df.createOrReplaceTempView(" t1" )
235- val arrayPrimitiveFields = df.schema.fields.filter {
222+ val arrayPrimitiveFields = spark.table(" t1" ).schema.fields.filter {
236223 case StructField (_, ArrayType (elemDt, _), _, _) if ! isComplexType(elemDt) => true
237224 case _ => false
238225 }
@@ -256,17 +243,12 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
256243 }
257244
258245 /**
259- * Element-level fuzz for map key and value reads. `map_keys(col)` / `map_values(col)` produce
260- * arrays the kernel walks via Spark's `ArrayMax`, exercising the map's child key/value getter.
261- * The leaf primitive read is structurally the same as in the array element fuzz, but the parent
262- * offset chain (MapVector -> entries StructVector -> child) differs, so a buggy unsafe getter
263- * that mishandled the map's per-row offset would slip past the array test alone. Filters to
264- * top-level `Map<primitive, primitive>` columns from the random nested schema.
246+ * Map variant of the array element fuzz: `map_keys` / `map_values` produce arrays the kernel
247+ * walks via `ArrayMax`, exercising the map's per-row offset chain (MapVector -> entries
248+ * StructVector -> child) that the array test alone wouldn't catch.
265249 */
266250 test(" array_max element fuzz: map_keys / map_values on Map<primitive, primitive> columns" ) {
267- val df = spark.read.parquet(nestedTypesFilename)
268- df.createOrReplaceTempView(" t2" )
269- val mapPrimitiveFields = df.schema.fields.filter {
251+ val mapPrimitiveFields = spark.table(" t2" ).schema.fields.filter {
270252 case StructField (_, MapType (kDt, vDt, _), _, _)
271253 if ! isComplexType(kDt) && ! isComplexType(vDt) =>
272254 true
@@ -288,64 +270,44 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
288270 }
289271 }
290272
273+ private def probeCardinality (accessor : String , viewName : String ): Unit = {
274+ assertCodegenRan {
275+ checkSparkAnswerAndOperator(
276+ s " SELECT $cardinalityProbeUdf(cardinality( $accessor)) FROM $viewName" )
277+ }
278+ }
279+
291280 /**
292- * Probes one complex top-level column. ArrayType / MapType go through `cardinality(col)` fed to
293- * the identity-Int probe UDF (see [[cardinalityProbeUdf ]] for the rationale). StructType drills
294- * into each scalar child via `GetStructField` and runs the identity UDF on it; complex children
295- * are recursed via the same dot-path (depth bounded by the schema generator).
281+ * Top-level Array / Map → cardinality probe. Struct → drill into each scalar child via
282+ * `GetStructField`; nested Array / Map sub-fields also get the cardinality probe (depth bound:
283+ * deeper struct-of-struct nesting is skipped to keep the sweep finite).
296284 */
297285 private def probeComplexColumn (field : StructField , viewName : String ): Unit = {
298286 field.dataType match {
299287 case _ : ArrayType | _ : MapType =>
300- assertCodegenRan {
301- checkSparkAnswerAndOperator(
302- s " SELECT $cardinalityProbeUdf(cardinality( ${field.name})) FROM $viewName" )
303- }
288+ probeCardinality(field.name, viewName)
304289
305290 case st : StructType =>
306291 for (subField <- st.fields) {
307292 val accessor = s " ${field.name}. ${subField.name}"
308- if (isComplexType(subField.dataType)) {
309- probeComplexAccessor(subField, accessor, viewName)
310- } else {
311- val udfName = s " id_ ${field.name}_ ${subField.name}"
312- registerIdentityUdfFor(subField.dataType, udfName).foreach { _ =>
313- assertCodegenRan {
314- checkSparkAnswerAndOperator(s " SELECT $udfName( $accessor) FROM $viewName" )
293+ subField.dataType match {
294+ case _ : ArrayType | _ : MapType => probeCardinality(accessor, viewName)
295+ case dt if ! isComplexType(dt) =>
296+ val udfName = s " id_ ${field.name}_ ${subField.name}"
297+ registerIdentityUdfFor(dt, udfName).foreach { _ =>
298+ assertCodegenRan {
299+ checkSparkAnswerAndOperator(s " SELECT $udfName( $accessor) FROM $viewName" )
300+ }
315301 }
316- }
302+ case _ => // deeper struct nesting skipped
317303 }
318304 }
319305
320- case _ => // not complex; caller filtered
321- }
322- }
323-
324- /**
325- * Probes a complex sub-field reached via dot access (e.g. `s.items` for an inner array). The
326- * dispatcher's bound tree carries `Cardinality(GetStructField(...))` around the kernel's
327- * complex column read.
328- */
329- private def probeComplexAccessor (
330- field : StructField ,
331- accessor : String ,
332- viewName : String ): Unit = {
333- field.dataType match {
334- case _ : ArrayType | _ : MapType =>
335- assertCodegenRan {
336- checkSparkAnswerAndOperator(
337- s " SELECT $cardinalityProbeUdf(cardinality( $accessor)) FROM $viewName" )
338- }
339- case _ => // deeper struct nesting skipped to keep the sweep bounded
306+ case _ =>
340307 }
341308 }
342309
343- /**
344- * Randomized decimal identity UDF. Spans both sides of the `Decimal.MAX_LONG_DIGITS` (18)
345- * boundary so each test hits one of the two specialized branches in the generated `getDecimal`
346- * getter. Precisions are chosen to exercise: small short-precision, boundary short-precision
347- * with varying scale, just-past-boundary long precision, and the max decimal128 precision.
348- */
310+ /** Random `BigDecimal` values fitting `(precision, scale)`, with `nullDensity` of them null. */
349311 private def generateDecimals (
350312 seed : Long ,
351313 precision : Int ,
@@ -389,11 +351,6 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
389351 (precision, scale) <- decimalShapes
390352 } {
391353 test(s " decimal identity precision= $precision scale= $scale nullDensity= $density" ) {
392- // Reuse one registered UDF name across iterations; Spark replaces by name. The Scala-side
393- // signature uses `BigDecimal`, which Spark encodes as DecimalType(38, 18); an implicit Cast
394- // from the column's DecimalType to the UDF's parameter type runs inside Spark's generated
395- // code, but the column read still goes through our kernel's `getDecimal` which is the path
396- // we're fuzzing.
397354 spark.udf.register(" dec_id_fuzz" , (d : java.math.BigDecimal ) => d)
398355 val seed = ((precision * 31L ) + scale) * 31L + density.hashCode
399356 val values = generateDecimals(seed, precision, scale, density)
0 commit comments