Skip to content

Commit 965c2ba

Browse files
committed
better input fuzz coverage
1 parent 6643208 commit 965c2ba

1 file changed

Lines changed: 218 additions & 3 deletions

File tree

spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala

Lines changed: 218 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,97 @@
1919

2020
package org.apache.comet
2121

22+
import java.io.File
23+
import java.text.SimpleDateFormat
24+
2225
import scala.util.Random
2326

27+
import org.apache.commons.io.FileUtils
2428
import org.apache.spark.SparkConf
2529
import org.apache.spark.sql.CometTestBase
2630
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
31+
import org.apache.spark.sql.internal.SQLConf
32+
import org.apache.spark.sql.types._
2733

34+
import org.apache.comet.DataTypeSupport.isComplexType
35+
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions}
2836
import org.apache.comet.udf.codegen.CometScalaUDFCodegen
2937

3038
/**
31-
* Randomized tests for the Arrow-direct codegen dispatcher. Generates inputs at varying null
32-
* densities and runs them through ScalaUDFs that route through the dispatcher, asserting Comet
33-
* results agree with Spark. Fixes a seed per test for reproducibility.
39+
* Randomized tests for the Arrow-direct codegen dispatcher. Schema-driven coverage of every input
40+
* vector class via random parquet files, plus a decimal precision-scale sweep across the
41+
* `Decimal.MAX_LONG_DIGITS=18` boundary at varying null densities.
42+
*
43+
* Extends [[CometTestBase]] (not [[CometFuzzTestBase]]) and inlines the random parquet setup so
44+
* tests run once. The base's three-way cross-product (`shuffle` x `nativeC2R`) does not change
45+
* the codegen path for projection-only queries, so it would be runtime cost without coverage.
3446
*/
3547
class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3648

49+
/** Random schema with primitives plus shallow arrays and structs. No maps, no deep nesting. */
50+
private var mixedTypesFilename: String = _
51+
52+
/** Random schema with deeply nested arrays / structs / maps. */
53+
private var nestedTypesFilename: String = _
54+
55+
/** Asia/Kathmandu has a non-zero minute offset (UTC+5:45); good for timezone edge cases. */
56+
private val defaultTimezone = "Asia/Kathmandu"
57+
58+
override def beforeAll(): Unit = {
59+
super.beforeAll()
60+
val tempDir = System.getProperty("java.io.tmpdir")
61+
val random = new Random(42)
62+
val dataGenOptions = DataGenOptions(
63+
generateNegativeZero = false,
64+
baseDate = new SimpleDateFormat("YYYY-MM-DD hh:mm:ss")
65+
.parse("2024-05-25 12:34:56")
66+
.getTime)
67+
68+
mixedTypesFilename =
69+
s"$tempDir/CometCodegenDispatchFuzzSuite_${System.currentTimeMillis()}.parquet"
70+
withSQLConf(
71+
CometConf.COMET_ENABLED.key -> "false",
72+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
73+
val schemaGenOptions =
74+
SchemaGenOptions(generateArray = true, generateStruct = true)
75+
ParquetGenerator.makeParquetFile(
76+
random,
77+
spark,
78+
mixedTypesFilename,
79+
1000,
80+
schemaGenOptions,
81+
dataGenOptions)
82+
}
83+
84+
nestedTypesFilename =
85+
s"$tempDir/CometCodegenDispatchFuzzSuite_nested_${System.currentTimeMillis()}.parquet"
86+
withSQLConf(
87+
CometConf.COMET_ENABLED.key -> "false",
88+
SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) {
89+
val schemaGenOptions =
90+
SchemaGenOptions(generateArray = true, generateStruct = true, generateMap = true)
91+
val schema = FuzzDataGenerator.generateNestedSchema(
92+
random,
93+
numCols = 10,
94+
minDepth = 2,
95+
maxDepth = 4,
96+
options = schemaGenOptions)
97+
ParquetGenerator.makeParquetFile(
98+
random,
99+
spark,
100+
nestedTypesFilename,
101+
schema,
102+
1000,
103+
dataGenOptions)
104+
}
105+
}
106+
107+
protected override def afterAll(): Unit = {
108+
super.afterAll()
109+
FileUtils.deleteDirectory(new File(mixedTypesFilename))
110+
FileUtils.deleteDirectory(new File(nestedTypesFilename))
111+
}
112+
37113
private val RowCount: Int = 512
38114
private val nullDensities: Seq[Double] = Seq(0.0, 0.1, 0.5, 1.0)
39115
// (precision, scale) pairs spanning both sides of the MAX_LONG_DIGITS=18 boundary.
@@ -57,6 +133,145 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
57133
s"expected at least one codegen dispatcher invocation during this query, got $after")
58134
}
59135

136+
/**
137+
* Identity ScalaUDF for one of the 14 primitive types in
138+
* [[org.apache.comet.testing.SchemaGenOptions.defaultPrimitiveTypes]]. Returns the registered
139+
* name when the type maps to a known Scala arg, or `None` for shapes we choose not to probe.
140+
* `BigDecimal` UDF args are encoded as `DecimalType(38, 18)`; Spark inserts an implicit cast
141+
* around the call but the underlying column read still hits our kernel's `getDecimal` at the
142+
* column's native precision.
143+
*/
144+
private def registerIdentityUdfFor(dt: DataType, name: String): Option[String] = dt match {
145+
case _: BooleanType => spark.udf.register(name, (x: Boolean) => x); Some(name)
146+
case _: ByteType => spark.udf.register(name, (x: Byte) => x); Some(name)
147+
case _: ShortType => spark.udf.register(name, (x: Short) => x); Some(name)
148+
case _: IntegerType => spark.udf.register(name, (x: Int) => x); Some(name)
149+
case _: LongType => spark.udf.register(name, (x: Long) => x); Some(name)
150+
case _: FloatType => spark.udf.register(name, (x: Float) => x); Some(name)
151+
case _: DoubleType => spark.udf.register(name, (x: Double) => x); Some(name)
152+
case _: DecimalType =>
153+
spark.udf.register(name, (x: java.math.BigDecimal) => x); Some(name)
154+
case _: DateType => spark.udf.register(name, (x: java.sql.Date) => x); Some(name)
155+
case _: TimestampType =>
156+
spark.udf.register(name, (x: java.sql.Timestamp) => x); Some(name)
157+
case _: TimestampNTZType =>
158+
spark.udf.register(name, (x: java.time.LocalDateTime) => x); Some(name)
159+
case _: StringType => spark.udf.register(name, (x: String) => x); Some(name)
160+
case _: BinaryType => spark.udf.register(name, (x: Array[Byte]) => x); Some(name)
161+
case _ => None
162+
}
163+
164+
/**
165+
* Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map
166+
* column, regardless of element type.
167+
*
168+
* Avoiding `Seq[T]` / `Map[K, V]` materialization is deliberate: Spark's
169+
* `org.apache.spark.sql.catalyst.expressions.objects.MapObjects` codegen reads each element via
170+
* `getLong`/`getFloat`/etc. unconditionally and only checks `isNullAt` afterward to decide
171+
* whether to wrap the value in `Option` or null. On null positions of a dictionary-encoded
172+
* primitive Arrow vector the underlying ID buffer holds uninitialized bytes, and
173+
* `decodeToLong/decodeToFloat` against those garbage IDs throws
174+
* `ArrayIndexOutOfBoundsException`. The buggy code is in Spark; the failure reproduces in pure
175+
* Spark execution (no Comet on the trace), so `checkSparkAnswerAndOperator` cannot compute the
176+
* baseline answer. `cardinality(col)` exercises the kernel's `getArray`/`getMap` length read
177+
* while bypassing the element deserializer entirely.
178+
*/
179+
private lazy val cardinalityProbeUdf: String = {
180+
val name = "sz_complex"
181+
spark.udf.register(name, (i: Int) => i)
182+
name
183+
}
184+
185+
test("identity ScalaUDF over every primitive column") {
186+
val df = spark.read.parquet(mixedTypesFilename)
187+
df.createOrReplaceTempView("t1")
188+
val primitiveFields = df.schema.fields.filterNot(f => isComplexType(f.dataType))
189+
assert(primitiveFields.nonEmpty, "expected at least one primitive column in random schema")
190+
for (field <- primitiveFields) {
191+
val udfName = s"id_${field.name}"
192+
registerIdentityUdfFor(field.dataType, udfName) match {
193+
case Some(_) =>
194+
assertCodegenRan {
195+
checkSparkAnswerAndOperator(s"SELECT $udfName(${field.name}) FROM t1")
196+
}
197+
case None =>
198+
fail(
199+
s"primitive column ${field.name}: ${field.dataType} not in identity UDF catalog; " +
200+
"extend registerIdentityUdfFor")
201+
}
202+
}
203+
}
204+
205+
test("complex-probe ScalaUDF on every complex column") {
206+
val df = spark.read.parquet(mixedTypesFilename)
207+
df.createOrReplaceTempView("t1")
208+
val complexFields = df.schema.fields.filter(f => isComplexType(f.dataType))
209+
assert(complexFields.nonEmpty, "expected at least one complex column in random schema")
210+
for (field <- complexFields) {
211+
probeComplexColumn(field, viewName = "t1")
212+
}
213+
}
214+
215+
test("complex-probe ScalaUDF on top-level columns of deeply nested schema") {
216+
val df = spark.read.parquet(nestedTypesFilename)
217+
df.createOrReplaceTempView("t2")
218+
for (field <- df.schema.fields) {
219+
probeComplexColumn(field, viewName = "t2")
220+
}
221+
}
222+
223+
/**
224+
* Probes one complex top-level column. ArrayType / MapType go through `cardinality(col)` fed to
225+
* the identity-Int probe UDF (see [[cardinalityProbeUdf]] for the rationale). StructType drills
226+
* into each scalar child via `GetStructField` and runs the identity UDF on it; complex children
227+
* are recursed via the same dot-path (depth bounded by the schema generator).
228+
*/
229+
private def probeComplexColumn(field: StructField, viewName: String): Unit = {
230+
field.dataType match {
231+
case _: ArrayType | _: MapType =>
232+
assertCodegenRan {
233+
checkSparkAnswerAndOperator(
234+
s"SELECT $cardinalityProbeUdf(cardinality(${field.name})) FROM $viewName")
235+
}
236+
237+
case st: StructType =>
238+
for (subField <- st.fields) {
239+
val accessor = s"${field.name}.${subField.name}"
240+
if (isComplexType(subField.dataType)) {
241+
probeComplexAccessor(subField, accessor, viewName)
242+
} else {
243+
val udfName = s"id_${field.name}_${subField.name}"
244+
registerIdentityUdfFor(subField.dataType, udfName).foreach { _ =>
245+
assertCodegenRan {
246+
checkSparkAnswerAndOperator(s"SELECT $udfName($accessor) FROM $viewName")
247+
}
248+
}
249+
}
250+
}
251+
252+
case _ => // not complex; caller filtered
253+
}
254+
}
255+
256+
/**
257+
* Probes a complex sub-field reached via dot access (e.g. `s.items` for an inner array). The
258+
* dispatcher's bound tree carries `Cardinality(GetStructField(...))` around the kernel's
259+
* complex column read.
260+
*/
261+
private def probeComplexAccessor(
262+
field: StructField,
263+
accessor: String,
264+
viewName: String): Unit = {
265+
field.dataType match {
266+
case _: ArrayType | _: MapType =>
267+
assertCodegenRan {
268+
checkSparkAnswerAndOperator(
269+
s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName")
270+
}
271+
case _ => // deeper struct nesting skipped to keep the sweep bounded
272+
}
273+
}
274+
60275
/**
61276
* Randomized decimal identity UDF. Spans both sides of the `Decimal.MAX_LONG_DIGITS` (18)
62277
* boundary so each test hits one of the two specialized branches in the generated `getDecimal`

0 commit comments

Comments
 (0)