Skip to content

Commit 41fc046

Browse files
committed
better input fuzz coverage
1 parent 948f3b9 commit 41fc046

1 file changed

Lines changed: 50 additions & 93 deletions

File tree

spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala

Lines changed: 50 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,11 @@ import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGener
3636
import org.apache.comet.udf.codegen.CometScalaUDFCodegen
3737

3838
/**
39-
* Randomized tests for the Arrow-direct codegen dispatcher. Schema-driven coverage of every input
40-
* vector class via random parquet files, plus a decimal precision-scale sweep across the
41-
* `Decimal.MAX_LONG_DIGITS=18` boundary at varying null densities.
42-
*
43-
* Extends [[CometTestBase]] (not [[CometFuzzTestBase]]) and inlines the random parquet setup so
44-
* tests run once. The base's three-way cross-product (`shuffle` x `nativeC2R`) does not change
45-
* the codegen path for projection-only queries, so it would be runtime cost without coverage.
39+
* Randomized tests for the Arrow-direct codegen dispatcher: schema-driven coverage of every input
40+
* vector class, plus a decimal precision-scale sweep across the `Decimal.MAX_LONG_DIGITS=18`
41+
* boundary at varying null densities. Extends [[CometTestBase]] (not [[CometFuzzTestBase]])
42+
* because the base's `shuffle` x `nativeC2R` cross-product `test()` override is irrelevant for
43+
* projection-only queries.
4644
*/
4745
class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper {
4846

@@ -102,6 +100,9 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
102100
1000,
103101
dataGenOptions)
104102
}
103+
104+
spark.read.parquet(mixedTypesFilename).createOrReplaceTempView("t1")
105+
spark.read.parquet(nestedTypesFilename).createOrReplaceTempView("t2")
105106
}
106107

107108
protected override def afterAll(): Unit = {
@@ -112,7 +113,8 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
112113

113114
private val RowCount: Int = 512
114115
private val nullDensities: Seq[Double] = Seq(0.0, 0.1, 0.5, 1.0)
115-
// (precision, scale) pairs spanning both sides of the MAX_LONG_DIGITS=18 boundary.
116+
// (precision, scale) shapes spanning both sides of `Decimal.MAX_LONG_DIGITS=18`: small short,
117+
// boundary short with varying scale, just-past-boundary long, and max decimal128.
116118
private val decimalShapes: Seq[(Int, Int)] = Seq((9, 2), (18, 0), (18, 9), (19, 0), (38, 10))
117119

118120
override protected def sparkConf: SparkConf =
@@ -165,16 +167,12 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
165167
* Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map
166168
* column, regardless of element type.
167169
*
168-
* Avoiding `Seq[T]` / `Map[K, V]` materialization is deliberate: Spark's
169-
* `org.apache.spark.sql.catalyst.expressions.objects.MapObjects` codegen reads each element via
170-
* `getLong`/`getFloat`/etc. unconditionally and only checks `isNullAt` afterward to decide
171-
* whether to wrap the value in `Option` or null. On null positions of a dictionary-encoded
172-
* primitive Arrow vector the underlying ID buffer holds uninitialized bytes, and
173-
* `decodeToLong/decodeToFloat` against those garbage IDs throws
174-
* `ArrayIndexOutOfBoundsException`. The buggy code is in Spark; the failure reproduces in pure
175-
* Spark execution (no Comet on the trace), so `checkSparkAnswerAndOperator` cannot compute the
176-
* baseline answer. `cardinality(col)` exercises the kernel's `getArray`/`getMap` length read
177-
* while bypassing the element deserializer entirely.
170+
* Avoids `Seq[T]` / `Map[K, V]` UDF arg materialization: Spark's `MapObjects.doGenCode` reads
171+
* each element unconditionally and null-checks afterward, so on null positions of a
172+
* dictionary-encoded primitive Arrow vector the garbage ID buffer feeds
173+
* `dictionary.decodeToLong/decodeToFloat` and throws `ArrayIndexOutOfBoundsException`. Bug
174+
* reproduces in pure Spark; `cardinality(col)` exercises `getArray`/`getMap` without entering
175+
* the element deserializer.
178176
*/
179177
private lazy val cardinalityProbeUdf: String = {
180178
val name = "sz_complex"
@@ -183,9 +181,8 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
183181
}
184182

185183
test("identity ScalaUDF over every primitive column") {
186-
val df = spark.read.parquet(mixedTypesFilename)
187-
df.createOrReplaceTempView("t1")
188-
val primitiveFields = df.schema.fields.filterNot(f => isComplexType(f.dataType))
184+
val primitiveFields =
185+
spark.table("t1").schema.fields.filterNot(f => isComplexType(f.dataType))
189186
assert(primitiveFields.nonEmpty, "expected at least one primitive column in random schema")
190187
for (field <- primitiveFields) {
191188
val udfName = s"id_${field.name}"
@@ -203,36 +200,26 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
203200
}
204201

205202
test("complex-probe ScalaUDF on every complex column") {
206-
val df = spark.read.parquet(mixedTypesFilename)
207-
df.createOrReplaceTempView("t1")
208-
val complexFields = df.schema.fields.filter(f => isComplexType(f.dataType))
203+
val complexFields = spark.table("t1").schema.fields.filter(f => isComplexType(f.dataType))
209204
assert(complexFields.nonEmpty, "expected at least one complex column in random schema")
210205
for (field <- complexFields) {
211206
probeComplexColumn(field, viewName = "t1")
212207
}
213208
}
214209

215210
test("complex-probe ScalaUDF on top-level columns of deeply nested schema") {
216-
val df = spark.read.parquet(nestedTypesFilename)
217-
df.createOrReplaceTempView("t2")
218-
for (field <- df.schema.fields) {
211+
for (field <- spark.table("t2").schema.fields) {
219212
probeComplexColumn(field, viewName = "t2")
220213
}
221214
}
222215

223216
/**
224-
* Element-level fuzz for nested array reads. For every `Array<primitive>` column in the random
225-
* schema, runs `id_X(array_max(col))` so Spark's `ArrayMax.doGenCode` walks every element of
226-
* every row and calls the kernel's nested element getter
227-
* (`getInt`/`getLong`/`getDecimal`/etc.). The cardinality probe deliberately avoids element
228-
* materialization, so without this test no fuzz coverage exists on the element-getter paths the
229-
* unsafe-access optimization would touch. `array_max` is comparison-only on every primitive
230-
* Spark supports, so one expression covers all 14 element types.
217+
* Element-level fuzz for nested array reads: `ArrayMax.doGenCode` walks every element of every
218+
* row, calling the kernel's nested element getter — the path the unsafe-getter optimization
219+
* touches and which the cardinality probe deliberately skips.
231220
*/
232221
test("array_max element fuzz: every Array<primitive> column") {
233-
val df = spark.read.parquet(mixedTypesFilename)
234-
df.createOrReplaceTempView("t1")
235-
val arrayPrimitiveFields = df.schema.fields.filter {
222+
val arrayPrimitiveFields = spark.table("t1").schema.fields.filter {
236223
case StructField(_, ArrayType(elemDt, _), _, _) if !isComplexType(elemDt) => true
237224
case _ => false
238225
}
@@ -256,17 +243,12 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
256243
}
257244

258245
/**
259-
* Element-level fuzz for map key and value reads. `map_keys(col)` / `map_values(col)` produce
260-
* arrays the kernel walks via Spark's `ArrayMax`, exercising the map's child key/value getter.
261-
* The leaf primitive read is structurally the same as in the array element fuzz, but the parent
262-
* offset chain (MapVector -> entries StructVector -> child) differs, so a buggy unsafe getter
263-
* that mishandled the map's per-row offset would slip past the array test alone. Filters to
264-
* top-level `Map<primitive, primitive>` columns from the random nested schema.
246+
* Map variant of the array element fuzz: `map_keys` / `map_values` produce arrays the kernel
247+
* walks via `ArrayMax`, exercising the map's per-row offset chain (MapVector -> entries
248+
* StructVector -> child) that the array test alone wouldn't catch.
265249
*/
266250
test("array_max element fuzz: map_keys / map_values on Map<primitive, primitive> columns") {
267-
val df = spark.read.parquet(nestedTypesFilename)
268-
df.createOrReplaceTempView("t2")
269-
val mapPrimitiveFields = df.schema.fields.filter {
251+
val mapPrimitiveFields = spark.table("t2").schema.fields.filter {
270252
case StructField(_, MapType(kDt, vDt, _), _, _)
271253
if !isComplexType(kDt) && !isComplexType(vDt) =>
272254
true
@@ -288,64 +270,44 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
288270
}
289271
}
290272

273+
private def probeCardinality(accessor: String, viewName: String): Unit = {
274+
assertCodegenRan {
275+
checkSparkAnswerAndOperator(
276+
s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName")
277+
}
278+
}
279+
291280
/**
292-
* Probes one complex top-level column. ArrayType / MapType go through `cardinality(col)` fed to
293-
* the identity-Int probe UDF (see [[cardinalityProbeUdf]] for the rationale). StructType drills
294-
* into each scalar child via `GetStructField` and runs the identity UDF on it; complex children
295-
* are recursed via the same dot-path (depth bounded by the schema generator).
281+
* Top-level Array / Map → cardinality probe. Struct → drill into each scalar child via
282+
* `GetStructField`; nested Array / Map sub-fields also get the cardinality probe (depth bound:
283+
* deeper struct-of-struct nesting is skipped to keep the sweep finite).
296284
*/
297285
private def probeComplexColumn(field: StructField, viewName: String): Unit = {
298286
field.dataType match {
299287
case _: ArrayType | _: MapType =>
300-
assertCodegenRan {
301-
checkSparkAnswerAndOperator(
302-
s"SELECT $cardinalityProbeUdf(cardinality(${field.name})) FROM $viewName")
303-
}
288+
probeCardinality(field.name, viewName)
304289

305290
case st: StructType =>
306291
for (subField <- st.fields) {
307292
val accessor = s"${field.name}.${subField.name}"
308-
if (isComplexType(subField.dataType)) {
309-
probeComplexAccessor(subField, accessor, viewName)
310-
} else {
311-
val udfName = s"id_${field.name}_${subField.name}"
312-
registerIdentityUdfFor(subField.dataType, udfName).foreach { _ =>
313-
assertCodegenRan {
314-
checkSparkAnswerAndOperator(s"SELECT $udfName($accessor) FROM $viewName")
293+
subField.dataType match {
294+
case _: ArrayType | _: MapType => probeCardinality(accessor, viewName)
295+
case dt if !isComplexType(dt) =>
296+
val udfName = s"id_${field.name}_${subField.name}"
297+
registerIdentityUdfFor(dt, udfName).foreach { _ =>
298+
assertCodegenRan {
299+
checkSparkAnswerAndOperator(s"SELECT $udfName($accessor) FROM $viewName")
300+
}
315301
}
316-
}
302+
case _ => // deeper struct nesting skipped
317303
}
318304
}
319305

320-
case _ => // not complex; caller filtered
321-
}
322-
}
323-
324-
/**
325-
* Probes a complex sub-field reached via dot access (e.g. `s.items` for an inner array). The
326-
* dispatcher's bound tree carries `Cardinality(GetStructField(...))` around the kernel's
327-
* complex column read.
328-
*/
329-
private def probeComplexAccessor(
330-
field: StructField,
331-
accessor: String,
332-
viewName: String): Unit = {
333-
field.dataType match {
334-
case _: ArrayType | _: MapType =>
335-
assertCodegenRan {
336-
checkSparkAnswerAndOperator(
337-
s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName")
338-
}
339-
case _ => // deeper struct nesting skipped to keep the sweep bounded
306+
case _ =>
340307
}
341308
}
342309

343-
/**
344-
* Randomized decimal identity UDF. Spans both sides of the `Decimal.MAX_LONG_DIGITS` (18)
345-
* boundary so each test hits one of the two specialized branches in the generated `getDecimal`
346-
* getter. Precisions are chosen to exercise: small short-precision, boundary short-precision
347-
* with varying scale, just-past-boundary long precision, and the max decimal128 precision.
348-
*/
310+
/** Random `BigDecimal` values fitting `(precision, scale)`, with `nullDensity` of them null. */
349311
private def generateDecimals(
350312
seed: Long,
351313
precision: Int,
@@ -389,11 +351,6 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
389351
(precision, scale) <- decimalShapes
390352
} {
391353
test(s"decimal identity precision=$precision scale=$scale nullDensity=$density") {
392-
// Reuse one registered UDF name across iterations; Spark replaces by name. The Scala-side
393-
// signature uses `BigDecimal`, which Spark encodes as DecimalType(38, 18); an implicit Cast
394-
// from the column's DecimalType to the UDF's parameter type runs inside Spark's generated
395-
// code, but the column read still goes through our kernel's `getDecimal` which is the path
396-
// we're fuzzing.
397354
spark.udf.register("dec_id_fuzz", (d: java.math.BigDecimal) => d)
398355
val seed = ((precision * 31L) + scale) * 31L + density.hashCode
399356
val values = generateDecimals(seed, precision, scale, density)

0 commit comments

Comments
 (0)