@@ -27,12 +27,18 @@ import scala.util.Random
2727import org .apache .commons .io .FileUtils
2828import org .apache .spark .SparkConf
2929import org .apache .spark .sql .CometTestBase
30+ import org .apache .spark .sql .catalyst .CatalystTypeConverters
31+ import org .apache .spark .sql .catalyst .expressions .{GenericInternalRow , Literal }
32+ import org .apache .spark .sql .catalyst .util .{ArrayBasedMapData , GenericArrayData }
3033import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
3134import org .apache .spark .sql .internal .SQLConf
3235import org .apache .spark .sql .types ._
36+ import org .apache .spark .unsafe .types .UTF8String
3337
3438import org .apache .comet .DataTypeSupport .isComplexType
39+ import org .apache .comet .codegen .CometBatchKernelCodegen
3540import org .apache .comet .testing .{DataGenOptions , FuzzDataGenerator , ParquetGenerator , SchemaGenOptions }
41+ import org .apache .comet .vector .CometVector
3642
3743/**
3844 * Randomized end-to-end tests for the Arrow-direct codegen dispatcher: schema-driven coverage of
@@ -406,4 +412,108 @@ class CometCodegenFuzzSuite
406412 }
407413 }
408414 }
415+
416+ /**
417+ * Randomized output-writer coverage (#4539). Generates a random nested output type and a random
418+ * catalyst value of that type, wraps it in a `Literal`, and drives it through the kernel output
419+ * writer with [[runKernel ]]. Reading the Arrow output back must reproduce the value.
420+ *
421+ * Random Array / Map sizes mean each collection's child vector fills at a cumulative index that
422+ * `numRows` does not bound, so the writer must grow the child with `setSafe` (the #4539 fix). A
423+ * multi-row batch additionally exercises the cumulative index across rows. The root is always a
424+ * collection so the nested-write path always runs. The generated value is its own oracle:
425+ * `CatalystTypeConverters.convertToScala` materializes both the value and the Arrow readback
426+ * (both expose the catalyst ArrayData / MapData / InternalRow interface) and the two must
427+ * compare equal.
428+ */
// Leaf types the random generators draw from; map keys are drawn only from this list.
private val outputLeafTypes: Seq[DataType] = Seq(
  IntegerType,
  LongType,
  DoubleType,
  BooleanType,
  StringType,
  DecimalType(10, 2))
431+
/** Picks one entry of `outputLeafTypes` at a random index drawn from `r`. */
private def randomLeafType(r: Random): DataType = {
  val idx = r.nextInt(outputLeafTypes.size)
  outputLeafTypes(idx)
}
434+
/**
 * Builds a random, possibly nested type. Returns a leaf once `depth` reaches zero, and otherwise
 * with probability 0.4; the remaining cases pick an array, a map, or a struct of one or two
 * fields whose children are generated at `depth - 1`. Map keys are always leaf types.
 */
private def randomOutputType(r: Random, depth: Int): DataType = {
  // The double is only drawn when depth is still positive (short-circuit), as callers rely on a
  // fixed draw order for reproducibility under a fixed seed.
  val stopAtLeaf = depth <= 0 || r.nextDouble() < 0.4
  if (stopAtLeaf) {
    randomLeafType(r)
  } else {
    val childDepth = depth - 1
    r.nextInt(3) match {
      case 0 =>
        ArrayType(randomOutputType(r, childDepth), containsNull = true)
      case 1 =>
        val keyType = randomLeafType(r)
        val valueType = randomOutputType(r, childDepth)
        MapType(keyType, valueType, valueContainsNull = true)
      case _ =>
        // One or two fields.
        val fieldCount = r.nextInt(2) + 1
        val fields = Seq.tabulate(fieldCount) { i =>
          StructField(s" f $i", randomOutputType(r, childDepth), nullable = true)
        }
        StructType(fields)
    }
  }
}
447+
/**
 * Produces a random catalyst value for the leaf type `dt`.
 *
 * @throws IllegalArgumentException
 *   if `dt` is not one of the leaf types handled here
 */
private def randomLeafValue(r: Random, dt: DataType): Any = {
  dt match {
    case BooleanType => r.nextBoolean()
    case IntegerType => r.nextInt()
    case LongType => r.nextLong()
    case DoubleType => r.nextDouble()
    case StringType =>
      val suffix = r.nextInt(1000000)
      UTF8String.fromString(s" s ${suffix}")
    case decimalType: DecimalType =>
      // Unscaled value in [-1000000, 999999], interpreted with the type's precision and scale.
      val unscaled = r.nextInt(2000000) - 1000000
      Decimal(unscaled.toLong, decimalType.precision, decimalType.scale)
    case other =>
      throw new IllegalArgumentException(s" unexpected leaf type $other")
  }
}
457+
/**
 * Generates a random catalyst value of type `dt`. When `nullable` is true the result is `null`
 * with probability 0.2. Arrays get 0-39 elements, maps are built from up to 19 key draws, and
 * structs get one value per field; anything else is delegated to `randomLeafValue`.
 */
private def randomOutputValue(r: Random, dt: DataType, nullable: Boolean): Any = {
  // The double is only drawn when a null is permitted (short-circuit).
  val emitNull = nullable && r.nextDouble() < 0.2
  if (emitNull) {
    null
  } else {
    dt match {
      case ArrayType(elementType, containsNull) =>
        val length = r.nextInt(40)
        val elements = Array.fill[Any](length)(randomOutputValue(r, elementType, containsNull))
        new GenericArrayData(elements)

      case MapType(keyType, valueType, valueContainsNull) =>
        // Keep only the first occurrence of each key so the map round-trips 1:1 (Spark map keys
        // are distinct). A value is generated only for a key that has not been seen yet.
        val entries = scala.collection.mutable.LinkedHashMap.empty[Any, Any]
        val attempts = r.nextInt(20)
        for (_ <- 0 until attempts) {
          val key = randomOutputValue(r, keyType, nullable = false)
          if (!entries.contains(key)) {
            entries.put(key, randomOutputValue(r, valueType, valueContainsNull))
          }
        }
        val keyArray = entries.keysIterator.toArray[Any]
        val valueArray = entries.valuesIterator.toArray[Any]
        new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))

      case structType: StructType =>
        val fieldValues: Array[Any] =
          structType.fields.map(field => randomOutputValue(r, field.dataType, field.nullable))
        new GenericInternalRow(fieldValues)

      case leaf =>
        randomLeafValue(r, leaf)
    }
  }
}
483+
/**
 * Fetches the root value of `vec` at `row`: `getArray` for an array type, `getMap` for a map
 * type.
 *
 * @throws IllegalArgumentException
 *   if `dt` is neither an array nor a map type
 */
private def readRoot(vec: CometVector, dt: DataType, row: Int): Any = {
  dt match {
    case _: MapType => vec.getMap(row)
    case _: ArrayType => vec.getArray(row)
    case unsupported =>
      throw new IllegalArgumentException(s" unexpected root type $unsupported")
  }
}
490+
test(" randomized dynamically-sized collection output round-trips through the writer (#4539)") {
  val rng = new Random(42)
  // More than one row, so the same literal is written several times within a single batch.
  val numRows = 4

  // Generates one random collection-typed literal and, if canHandle reports no objection, runs
  // it through the kernel and checks every row against the generated value. Returns whether the
  // literal was actually run.
  def exerciseOnce(): Boolean = {
    // The root is always an array or a map; only the children are fully random.
    val rootIsArray = rng.nextBoolean()
    val dt: DataType =
      if (rootIsArray) {
        ArrayType(randomOutputType(rng, 2), containsNull = true)
      } else {
        MapType(randomLeafType(rng), randomOutputType(rng, 2), valueContainsNull = true)
      }
    val value = randomOutputValue(rng, dt, nullable = false)
    val expr = Literal(value, dt)
    val handled = CometBatchKernelCodegen.canHandle(expr).isEmpty
    if (handled) {
      // The generated value is its own oracle: both sides are converted to Scala values and
      // compared.
      val expected = CatalystTypeConverters.convertToScala(value, dt)
      runKernel(expr, numRows) { vec =>
        for (row <- 0 until numRows) {
          val actual = CatalystTypeConverters.convertToScala(readRoot(vec, dt, row), dt)
          assert(
            actual === expected,
            s" row $row mismatch for output type $dt\n expected= $expected\n actual= $actual")
        }
      }
    }
    handled
  }

  // Count the literals that were really driven through the writer so that a run in which
  // canHandle rejected everything fails instead of passing vacuously.
  val exercised = (0 until 300).count(_ => exerciseOnce())
  assert(exercised > 0, " every generated type was rejected by canHandle (of 300 generated)")
}
409519}
0 commit comments