1919
2020package org .apache .comet
2121
22+ import java .io .File
23+ import java .text .SimpleDateFormat
24+
2225import scala .util .Random
2326
27+ import org .apache .commons .io .FileUtils
2428import org .apache .spark .SparkConf
2529import org .apache .spark .sql .CometTestBase
2630import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
31+ import org .apache .spark .sql .internal .SQLConf
32+ import org .apache .spark .sql .types ._
2733
34+ import org .apache .comet .DataTypeSupport .isComplexType
35+ import org .apache .comet .testing .{DataGenOptions , FuzzDataGenerator , ParquetGenerator , SchemaGenOptions }
2836import org .apache .comet .udf .codegen .CometScalaUDFCodegen
2937
3038/**
31- * Randomized tests for the Arrow-direct codegen dispatcher. Generates inputs at varying null
32- * densities and runs them through ScalaUDFs that route through the dispatcher, asserting Comet
33- * results agree with Spark. Fixes a seed per test for reproducibility.
39+ * Randomized tests for the Arrow-direct codegen dispatcher. Schema-driven coverage of every input
40+ * vector class via random parquet files, plus a decimal precision-scale sweep across the
41+ * `Decimal.MAX_LONG_DIGITS=18` boundary at varying null densities.
42+ *
43+ * Extends [[CometTestBase ]] (not [[CometFuzzTestBase ]]) and inlines the random parquet setup so
44+ * tests run once. The base's three-way cross-product (`shuffle` x `nativeC2R`) does not change
45+ * the codegen path for projection-only queries, so it would be runtime cost without coverage.
3446 */
3547class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3648
49+ /** Random schema with primitives plus shallow arrays and structs. No maps, no deep nesting. */
50+ private var mixedTypesFilename : String = _
51+
52+ /** Random schema with deeply nested arrays / structs / maps. */
53+ private var nestedTypesFilename : String = _
54+
55+ /** Asia/Kathmandu has a non-zero minute offset (UTC+5:45); good for timezone edge cases. */
56+ private val defaultTimezone = " Asia/Kathmandu"
57+
58+ override def beforeAll (): Unit = {
59+ super .beforeAll()
60+ val tempDir = System .getProperty(" java.io.tmpdir" )
61+ val random = new Random (42 )
62+ val dataGenOptions = DataGenOptions (
63+ generateNegativeZero = false ,
64+ baseDate = new SimpleDateFormat (" YYYY-MM-DD hh:mm:ss" )
65+ .parse(" 2024-05-25 12:34:56" )
66+ .getTime)
67+
68+ mixedTypesFilename =
69+ s " $tempDir/CometCodegenDispatchFuzzSuite_ ${System .currentTimeMillis()}.parquet "
70+ withSQLConf(
71+ CometConf .COMET_ENABLED .key -> " false" ,
72+ SQLConf .SESSION_LOCAL_TIMEZONE .key -> defaultTimezone) {
73+ val schemaGenOptions =
74+ SchemaGenOptions (generateArray = true , generateStruct = true )
75+ ParquetGenerator .makeParquetFile(
76+ random,
77+ spark,
78+ mixedTypesFilename,
79+ 1000 ,
80+ schemaGenOptions,
81+ dataGenOptions)
82+ }
83+
84+ nestedTypesFilename =
85+ s " $tempDir/CometCodegenDispatchFuzzSuite_nested_ ${System .currentTimeMillis()}.parquet "
86+ withSQLConf(
87+ CometConf .COMET_ENABLED .key -> " false" ,
88+ SQLConf .SESSION_LOCAL_TIMEZONE .key -> defaultTimezone) {
89+ val schemaGenOptions =
90+ SchemaGenOptions (generateArray = true , generateStruct = true , generateMap = true )
91+ val schema = FuzzDataGenerator .generateNestedSchema(
92+ random,
93+ numCols = 10 ,
94+ minDepth = 2 ,
95+ maxDepth = 4 ,
96+ options = schemaGenOptions)
97+ ParquetGenerator .makeParquetFile(
98+ random,
99+ spark,
100+ nestedTypesFilename,
101+ schema,
102+ 1000 ,
103+ dataGenOptions)
104+ }
105+ }
106+
107+ protected override def afterAll (): Unit = {
108+ super .afterAll()
109+ FileUtils .deleteDirectory(new File (mixedTypesFilename))
110+ FileUtils .deleteDirectory(new File (nestedTypesFilename))
111+ }
112+
37113 private val RowCount : Int = 512
38114 private val nullDensities : Seq [Double ] = Seq (0.0 , 0.1 , 0.5 , 1.0 )
39115 // (precision, scale) pairs spanning both sides of the MAX_LONG_DIGITS=18 boundary.
@@ -57,6 +133,145 @@ class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlan
57133 s " expected at least one codegen dispatcher invocation during this query, got $after" )
58134 }
59135
136+ /**
137+ * Identity ScalaUDF for one of the 14 primitive types in
138+ * [[org.apache.comet.testing.SchemaGenOptions.defaultPrimitiveTypes ]]. Returns the registered
139+ * name when the type maps to a known Scala arg, or `None` for shapes we choose not to probe.
140+ * `BigDecimal` UDF args are encoded as `DecimalType(38, 18)`; Spark inserts an implicit cast
141+ * around the call but the underlying column read still hits our kernel's `getDecimal` at the
142+ * column's native precision.
143+ */
144+ private def registerIdentityUdfFor (dt : DataType , name : String ): Option [String ] = dt match {
145+ case _ : BooleanType => spark.udf.register(name, (x : Boolean ) => x); Some (name)
146+ case _ : ByteType => spark.udf.register(name, (x : Byte ) => x); Some (name)
147+ case _ : ShortType => spark.udf.register(name, (x : Short ) => x); Some (name)
148+ case _ : IntegerType => spark.udf.register(name, (x : Int ) => x); Some (name)
149+ case _ : LongType => spark.udf.register(name, (x : Long ) => x); Some (name)
150+ case _ : FloatType => spark.udf.register(name, (x : Float ) => x); Some (name)
151+ case _ : DoubleType => spark.udf.register(name, (x : Double ) => x); Some (name)
152+ case _ : DecimalType =>
153+ spark.udf.register(name, (x : java.math.BigDecimal ) => x); Some (name)
154+ case _ : DateType => spark.udf.register(name, (x : java.sql.Date ) => x); Some (name)
155+ case _ : TimestampType =>
156+ spark.udf.register(name, (x : java.sql.Timestamp ) => x); Some (name)
157+ case _ : TimestampNTZType =>
158+ spark.udf.register(name, (x : java.time.LocalDateTime ) => x); Some (name)
159+ case _ : StringType => spark.udf.register(name, (x : String ) => x); Some (name)
160+ case _ : BinaryType => spark.udf.register(name, (x : Array [Byte ]) => x); Some (name)
161+ case _ => None
162+ }
163+
164+ /**
165+ * Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map
166+ * column, regardless of element type.
167+ *
168+ * Avoiding `Seq[T]` / `Map[K, V]` materialization is deliberate: Spark's
169+ * `org.apache.spark.sql.catalyst.expressions.objects.MapObjects` codegen reads each element via
170+ * `getLong`/`getFloat`/etc. unconditionally and only checks `isNullAt` afterward to decide
171+ * whether to wrap the value in `Option` or null. On null positions of a dictionary-encoded
172+ * primitive Arrow vector the underlying ID buffer holds uninitialized bytes, and
173+ * `decodeToLong/decodeToFloat` against those garbage IDs throws
174+ * `ArrayIndexOutOfBoundsException`. The buggy code is in Spark; the failure reproduces in pure
175+ * Spark execution (no Comet on the trace), so `checkSparkAnswerAndOperator` cannot compute the
176+ * baseline answer. `cardinality(col)` exercises the kernel's `getArray`/`getMap` length read
177+ * while bypassing the element deserializer entirely.
178+ */
179+ private lazy val cardinalityProbeUdf : String = {
180+ val name = " sz_complex"
181+ spark.udf.register(name, (i : Int ) => i)
182+ name
183+ }
184+
185+ test(" identity ScalaUDF over every primitive column" ) {
186+ val df = spark.read.parquet(mixedTypesFilename)
187+ df.createOrReplaceTempView(" t1" )
188+ val primitiveFields = df.schema.fields.filterNot(f => isComplexType(f.dataType))
189+ assert(primitiveFields.nonEmpty, " expected at least one primitive column in random schema" )
190+ for (field <- primitiveFields) {
191+ val udfName = s " id_ ${field.name}"
192+ registerIdentityUdfFor(field.dataType, udfName) match {
193+ case Some (_) =>
194+ assertCodegenRan {
195+ checkSparkAnswerAndOperator(s " SELECT $udfName( ${field.name}) FROM t1 " )
196+ }
197+ case None =>
198+ fail(
199+ s " primitive column ${field.name}: ${field.dataType} not in identity UDF catalog; " +
200+ " extend registerIdentityUdfFor" )
201+ }
202+ }
203+ }
204+
205+ test(" complex-probe ScalaUDF on every complex column" ) {
206+ val df = spark.read.parquet(mixedTypesFilename)
207+ df.createOrReplaceTempView(" t1" )
208+ val complexFields = df.schema.fields.filter(f => isComplexType(f.dataType))
209+ assert(complexFields.nonEmpty, " expected at least one complex column in random schema" )
210+ for (field <- complexFields) {
211+ probeComplexColumn(field, viewName = " t1" )
212+ }
213+ }
214+
215+ test(" complex-probe ScalaUDF on top-level columns of deeply nested schema" ) {
216+ val df = spark.read.parquet(nestedTypesFilename)
217+ df.createOrReplaceTempView(" t2" )
218+ for (field <- df.schema.fields) {
219+ probeComplexColumn(field, viewName = " t2" )
220+ }
221+ }
222+
223+ /**
224+ * Probes one complex top-level column. ArrayType / MapType go through `cardinality(col)` fed to
225+ * the identity-Int probe UDF (see [[cardinalityProbeUdf ]] for the rationale). StructType drills
226+ * into each scalar child via `GetStructField` and runs the identity UDF on it; complex children
227+ * are recursed via the same dot-path (depth bounded by the schema generator).
228+ */
229+ private def probeComplexColumn (field : StructField , viewName : String ): Unit = {
230+ field.dataType match {
231+ case _ : ArrayType | _ : MapType =>
232+ assertCodegenRan {
233+ checkSparkAnswerAndOperator(
234+ s " SELECT $cardinalityProbeUdf(cardinality( ${field.name})) FROM $viewName" )
235+ }
236+
237+ case st : StructType =>
238+ for (subField <- st.fields) {
239+ val accessor = s " ${field.name}. ${subField.name}"
240+ if (isComplexType(subField.dataType)) {
241+ probeComplexAccessor(subField, accessor, viewName)
242+ } else {
243+ val udfName = s " id_ ${field.name}_ ${subField.name}"
244+ registerIdentityUdfFor(subField.dataType, udfName).foreach { _ =>
245+ assertCodegenRan {
246+ checkSparkAnswerAndOperator(s " SELECT $udfName( $accessor) FROM $viewName" )
247+ }
248+ }
249+ }
250+ }
251+
252+ case _ => // not complex; caller filtered
253+ }
254+ }
255+
256+ /**
257+ * Probes a complex sub-field reached via dot access (e.g. `s.items` for an inner array). The
258+ * dispatcher's bound tree carries `Cardinality(GetStructField(...))` around the kernel's
259+ * complex column read.
260+ */
261+ private def probeComplexAccessor (
262+ field : StructField ,
263+ accessor : String ,
264+ viewName : String ): Unit = {
265+ field.dataType match {
266+ case _ : ArrayType | _ : MapType =>
267+ assertCodegenRan {
268+ checkSparkAnswerAndOperator(
269+ s " SELECT $cardinalityProbeUdf(cardinality( $accessor)) FROM $viewName" )
270+ }
271+ case _ => // deeper struct nesting skipped to keep the sweep bounded
272+ }
273+ }
274+
60275 /**
61276 * Randomized decimal identity UDF. Spans both sides of the `Decimal.MAX_LONG_DIGITS` (18)
62277 * boundary so each test hits one of the two specialized branches in the generated `getDecimal`
0 commit comments