Skip to content

Commit a0ae768

Browse files
authored
fix(codegen): Use setSafe for fixed-width writes into nested collection children whose element count is data-dependent (#4549)
1 parent dcf3307 commit a0ae768

5 files changed

Lines changed: 449 additions & 7 deletions

File tree

spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala

Lines changed: 27 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -186,19 +186,32 @@ private[codegen] object CometBatchKernelCodegenOutput {
186186
*
187187
* Scalars emit `perRow` only. Complex types emit both. Inner setup bubbles up so deep child
188188
* casts land at the batch prelude.
189+
*
190+
* `nested` distinguishes the root output vector from a child of a List / Map / Struct.
191+
* `allocateOutput` pre-sizes the root to exactly `numRows` and the kernel writes one scalar per
192+
* row, so the root's fixed-width `set` is always in bounds. A child's element count is instead
193+
* the data-dependent sum of per-row collection sizes, which `numRows` does not bound. We cannot
194+
* pre-size the child either: each row's `ArrayData` / `MapData` is produced by Spark's
195+
* generated `ev.code` inside the write loop, so the total is unknown until we have already
196+
* evaluated every row (counting it first would mean evaluating the tree twice). Nested
197+
* fixed-width writes therefore grow on demand with `setSafe`; the String / Binary / Decimal
198+
* branches already do, for the same reason.
189199
*/
190200
private def emitWrite(
191201
targetVec: String,
192202
idx: String,
193203
source: String,
194204
dataType: DataType,
195-
ctx: CodegenContext): OutputEmit = dataType match {
205+
ctx: CodegenContext,
206+
nested: Boolean = false): OutputEmit = dataType match {
196207
case BooleanType =>
197-
OutputEmit("", s"$targetVec.set($idx, $source ? 1 : 0);")
208+
val set = if (nested) "setSafe" else "set"
209+
OutputEmit("", s"$targetVec.$set($idx, $source ? 1 : 0);")
198210
case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType |
199211
TimestampType | TimestampNTZType =>
200212
// Spark codegen emits the matching primitive Java type; Arrow `set` overloads accept it.
201-
OutputEmit("", s"$targetVec.set($idx, $source);")
213+
val set = if (nested) "setSafe" else "set"
214+
OutputEmit("", s"$targetVec.$set($idx, $source);")
202215
case dt: DecimalType =>
203216
// DecimalOutputShortFastPath: precision <= 18 fits in a signed long, so pass the unscaled
204217
// value to `setSafe(int, long)` and skip the BigDecimal allocation.
@@ -250,7 +263,8 @@ private[codegen] object CometBatchKernelCodegenOutput {
250263
val childIdx = ctx.freshName("cidx")
251264
val jVar = ctx.freshName("j")
252265
val elemSource = emitSpecializedGetterExpr(arrVar, jVar, elementType)
253-
val inner = emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx)
266+
val inner =
267+
emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx, nested = true)
254268
val setup =
255269
(s"$childClass $childVar = ($childClass) $targetVec.getDataVector();" +:
256270
Seq(inner.setup).filter(_.nonEmpty)).mkString("\n")
@@ -285,7 +299,11 @@ private[codegen] object CometBatchKernelCodegenOutput {
285299
val childDecl =
286300
s"$childClass $childVar = ($childClass) $targetVec.getChildByOrdinal($fi);"
287301
val fieldSource = emitSpecializedGetterExpr(rowVar, fi.toString, field.dataType)
288-
val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx)
302+
// Struct fields are co-indexed with the struct (written at the same `idx`), so a field is
303+
// nested exactly when the struct is: top-level struct fields land at the row index and are
304+
// pre-sized to numRows (bare `set` is in bounds); a struct nested in an array/map inherits
305+
// that parent's cumulative, unbounded index and needs `setSafe`.
306+
val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx, nested = nested)
289307
val write =
290308
if (!field.nullable) {
291309
inner.perRow
@@ -327,8 +345,10 @@ private[codegen] object CometBatchKernelCodegenOutput {
327345
val valClass = outputVectorClass(mt.valueType)
328346
val keySrcExpr = emitSpecializedGetterExpr(keyArr, jVar, mt.keyType)
329347
val valSrcExpr = emitSpecializedGetterExpr(valArr, jVar, mt.valueType)
330-
val keyEmit = emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx)
331-
val valEmit = emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx)
348+
val keyEmit =
349+
emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx, nested = true)
350+
val valEmit =
351+
emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx, nested = true)
332352
val setup =
333353
(Seq(
334354
s"$structClass $entriesVar = ($structClass) $targetVec.getDataVector();",

spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
package org.apache.comet
2121

2222
import org.apache.arrow.vector.ValueVector
23+
import org.apache.spark.sql.catalyst.expressions.Expression
2324
import org.apache.spark.sql.types.DataType
2425

26+
import org.apache.comet.codegen.CometBatchKernelCodegen
2527
import org.apache.comet.udf.codegen.CometScalaUDFCodegen
28+
import org.apache.comet.vector.CometVector
2629

2730
/**
2831
* Shared assertions for the codegen-dispatcher test suites. Mix in alongside `CometTestBase`.
@@ -79,4 +82,27 @@ trait CometCodegenAssertions {
7982
s"expected kernel signature $expectedNames -> $output; " +
8083
s"cache had ${sigs.map { case (c, d) => (c.map(_.getSimpleName), d) }}")
8184
}
85+
86+
/**
87+
* Compiles `expr` (no input columns), runs one batch of `numRows`, and hands the output
88+
* `CometVector` to `read`. Every row evaluates to the same value (the expression has no input),
89+
* which still exercises the cross-row cumulative child index of the collection output writer:
90+
* the child of a List / Map grows by each row's element count, so a batch of N rows drives the
91+
* accumulation that a single row cannot. Drives the writer directly, without a query plan, so
92+
* it reaches complex-output expressions the serde does not route through dispatch today. The
93+
* vector is closed after `read` returns, so `read` must fully materialize what it needs.
94+
*/
95+
protected def runKernel[T](expr: Expression, numRows: Int)(read: CometVector => T): T = {
96+
val kernel = CometBatchKernelCodegen.compile(expr, IndexedSeq.empty).newInstance()
97+
val field = CometBatchKernelCodegen.toFfiArrowField("out", expr.dataType, nullable = true)
98+
val out = CometBatchKernelCodegen.allocateOutput(field, numRows, 0)
99+
try {
100+
kernel.init(0)
101+
kernel.process(Array.empty[ValueVector], out, numRows)
102+
out.setValueCount(numRows)
103+
read(CometVector.getVector(out, null))
104+
} finally {
105+
out.close()
106+
}
107+
}
82108
}

spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,18 @@ import scala.util.Random
2727
import org.apache.commons.io.FileUtils
2828
import org.apache.spark.SparkConf
2929
import org.apache.spark.sql.CometTestBase
30+
import org.apache.spark.sql.catalyst.CatalystTypeConverters
31+
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal}
32+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
3033
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3134
import org.apache.spark.sql.internal.SQLConf
3235
import org.apache.spark.sql.types._
36+
import org.apache.spark.unsafe.types.UTF8String
3337

3438
import org.apache.comet.DataTypeSupport.isComplexType
39+
import org.apache.comet.codegen.CometBatchKernelCodegen
3540
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions}
41+
import org.apache.comet.vector.CometVector
3642

3743
/**
3844
* Randomized end-to-end tests for the Arrow-direct codegen dispatcher: schema-driven coverage of
@@ -406,4 +412,108 @@ class CometCodegenFuzzSuite
406412
}
407413
}
408414
}
415+
416+
/**
417+
* Randomized output-writer coverage (#4539). Generates a random nested output type and a random
418+
* catalyst value of that type, wraps it in a `Literal`, and drives it through the kernel output
419+
* writer with [[runKernel]]. Reading the Arrow output back must reproduce the value.
420+
*
421+
* Random Array / Map sizes mean each collection's child vector fills at a cumulative index that
422+
* `numRows` does not bound, so the writer must grow the child with `setSafe` (the #4539 fix). A
423+
* multi-row batch additionally exercises the cumulative index across rows. The root is always a
424+
* collection so the nested-write path always runs. The generated value is its own oracle:
425+
* `CatalystTypeConverters.convertToScala` materializes both the value and the Arrow readback
426+
* (both expose the catalyst ArrayData / MapData / InternalRow interface) and the two must
427+
* compare equal.
428+
*/
429+
private val outputLeafTypes: Seq[DataType] =
430+
Seq(IntegerType, LongType, DoubleType, BooleanType, StringType, DecimalType(10, 2))
431+
432+
private def randomLeafType(r: Random): DataType =
433+
outputLeafTypes(r.nextInt(outputLeafTypes.size))
434+
435+
/** Random nested type, biased toward leaves as depth runs out. Map keys are always leaves. */
436+
private def randomOutputType(r: Random, depth: Int): DataType =
437+
if (depth <= 0 || r.nextDouble() < 0.4) randomLeafType(r)
438+
else
439+
r.nextInt(3) match {
440+
case 0 => ArrayType(randomOutputType(r, depth - 1), containsNull = true)
441+
case 1 =>
442+
MapType(randomLeafType(r), randomOutputType(r, depth - 1), valueContainsNull = true)
443+
case _ =>
444+
StructType((0 to r.nextInt(2)).map(i =>
445+
StructField(s"f$i", randomOutputType(r, depth - 1), nullable = true)))
446+
}
447+
448+
private def randomLeafValue(r: Random, dt: DataType): Any = dt match {
449+
case IntegerType => r.nextInt()
450+
case LongType => r.nextLong()
451+
case DoubleType => r.nextDouble()
452+
case BooleanType => r.nextBoolean()
453+
case StringType => UTF8String.fromString(s"s${r.nextInt(1000000)}")
454+
case d: DecimalType => Decimal((r.nextInt(2000000) - 1000000).toLong, d.precision, d.scale)
455+
case other => throw new IllegalArgumentException(s"unexpected leaf type $other")
456+
}
457+
458+
/** Random catalyst value of `dt`; `nullable` permits an occasional null element / field. */
459+
private def randomOutputValue(r: Random, dt: DataType, nullable: Boolean): Any = {
460+
if (nullable && r.nextDouble() < 0.2) null
461+
else
462+
dt match {
463+
case ArrayType(e, containsNull) =>
464+
val n = r.nextInt(40)
465+
new GenericArrayData(
466+
(0 until n).map(_ => randomOutputValue(r, e, containsNull)).toArray[Any])
467+
case MapType(k, v, valueContainsNull) =>
468+
// Dedup by materialized key so the map round-trips 1:1 (Spark map keys are distinct).
469+
val entries = scala.collection.mutable.LinkedHashMap.empty[Any, Any]
470+
(0 until r.nextInt(20)).foreach { _ =>
471+
val key = randomOutputValue(r, k, nullable = false)
472+
entries.getOrElseUpdate(key, randomOutputValue(r, v, valueContainsNull))
473+
}
474+
new ArrayBasedMapData(
475+
new GenericArrayData(entries.keys.toArray[Any]),
476+
new GenericArrayData(entries.values.toArray[Any]))
477+
case st: StructType =>
478+
new GenericInternalRow(
479+
st.fields.map(f => randomOutputValue(r, f.dataType, f.nullable)).toArray[Any])
480+
case leaf => randomLeafValue(r, leaf)
481+
}
482+
}
483+
484+
/** Reads the root collection value of `vec` at `row` as a catalyst ArrayData / MapData. */
485+
private def readRoot(vec: CometVector, dt: DataType, row: Int): Any = dt match {
486+
case _: ArrayType => vec.getArray(row)
487+
case _: MapType => vec.getMap(row)
488+
case other => throw new IllegalArgumentException(s"unexpected root type $other")
489+
}
490+
491+
test("randomized dynamically-sized collection output round-trips through the writer (#4539)") {
492+
val r = new Random(42)
493+
val numRows = 4 // > 1 so the child's cumulative index accumulates across rows
494+
// canHandle may reject a generated type (e.g. the maxFields gate on a wide nesting); count
495+
// the ones we actually drove through the writer to guard against a vacuous run.
496+
val exercised = (0 until 300).count { _ =>
497+
// Root is always a collection so the nested-child write path runs every iteration.
498+
val dt =
499+
if (r.nextBoolean()) ArrayType(randomOutputType(r, 2), containsNull = true)
500+
else MapType(randomLeafType(r), randomOutputType(r, 2), valueContainsNull = true)
501+
val value = randomOutputValue(r, dt, nullable = false)
502+
val expr = Literal(value, dt)
503+
val handled = CometBatchKernelCodegen.canHandle(expr).isEmpty
504+
if (handled) {
505+
val expected = CatalystTypeConverters.convertToScala(value, dt)
506+
runKernel(expr, numRows) { vec =>
507+
(0 until numRows).foreach { row =>
508+
val actual = CatalystTypeConverters.convertToScala(readRoot(vec, dt, row), dt)
509+
assert(
510+
actual === expected,
511+
s"row $row mismatch for output type $dt\n expected=$expected\n actual=$actual")
512+
}
513+
}
514+
}
515+
handled
516+
}
517+
assert(exercised > 0, "every generated type was rejected by canHandle (of 300 generated)")
518+
}
409519
}

spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,46 @@ class CometCodegenSourceSuite extends AnyFunSuite {
450450
}
451451
}
452452

453+
test("nested fixed-width map children grow with setSafe, not set (#4539)") {
454+
// Map<Int, Int> output: both key and value are fixed-width children of the entries struct.
455+
// Their element count is the data-dependent sum of per-row map sizes, not bounded by numRows,
456+
// and is unknown until the write loop has evaluated each row, so the writes must use `setSafe`
457+
// to grow on demand. A bare `set` throws once a row's entries exceed the child's initial
458+
// capacity (issue #4539: the literal map's third key overflowed the pre-sized IntVector).
459+
val expr = CreateMap(
460+
Seq(
461+
Literal(1, IntegerType),
462+
Literal(10, IntegerType),
463+
Literal(2, IntegerType),
464+
Literal(20, IntegerType)))
465+
val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body
466+
assert(
467+
src.contains(".setSafe("),
468+
s"expected setSafe for nested fixed-width writes; got:\n$src")
469+
// `.set(` is a bare fixed-width write; `setSafe(` / `setNull(` / `setIndexDefined(` do not
470+
// match this literal. There must be none into the nested children.
471+
assert(
472+
!src.contains(".set("),
473+
s"expected no bare fixed-width set into map children; got:\n$src")
474+
}
475+
476+
test("top-level scalar output keeps the pre-sized set fast path") {
477+
// The root output vector is pre-sized to numRows and written once per row, so it uses the
478+
// bare `set` fast path rather than paying for setSafe's per-write capacity check. This pins
479+
// the boundary the #4539 fix draws: setSafe is for nested children only.
480+
val expr = Add(BoundReference(0, IntegerType, nullable = false), Literal(1, IntegerType))
481+
val intSpec = ArrowColumnSpec(
482+
CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"),
483+
nullable = false)
484+
val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body
485+
assert(
486+
src.contains("output.set("),
487+
s"expected bare set for the pre-sized root output; got:\n$src")
488+
assert(
489+
!src.contains(".setSafe("),
490+
s"expected no setSafe for a scalar root output; got:\n$src")
491+
}
492+
453493
test("ArrayType output elides isNullAt on the element loop when containsNull is false") {
454494
// CreateArray over only-non-null Literals produces ArrayType(elementType, containsNull=false).
455495
// The element write should drop the `arr.isNullAt(j)` guard at source level rather than

0 commit comments

Comments
 (0)