Skip to content

Commit d2fbd6e

Browse files
authored
fix: fix memory safety issue in native c2r (#3367)
1 parent 1babf36 commit d2fbd6e

2 files changed

Lines changed: 66 additions & 2 deletions

File tree

spark/src/main/scala/org/apache/comet/NativeColumnarToRowConverter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,6 @@ private class NativeRowIterator(info: NativeColumnarToRowInfo, unsafeRow: Unsafe
139139
unsafeRow.pointTo(null, rowAddress, rowSize)
140140
currentIdx += 1
141141

142-
unsafeRow
142+
unsafeRow.copy()
143143
}
144144
}

spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.sql.comet.CometNativeColumnarToRowExec
3232
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3333
import org.apache.spark.sql.types._
3434

35-
import org.apache.comet.CometConf
35+
import org.apache.comet.{CometConf, NativeColumnarToRowConverter}
3636
import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, SchemaGenOptions}
3737

3838
/**
@@ -470,6 +470,70 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan
470470
}
471471
}
472472

473+
// Regression test for https://github.com/apache/datafusion-comet/issues/3308
474+
// Native columnar-to-row returns UnsafeRow pointing into a Rust-owned buffer that is
475+
// cleared/reused on each convert() call. This test directly exercises the converter:
476+
// it converts multiple batches and holds row references from earlier batches, then
477+
// verifies they still contain correct data. Without a fix (e.g., copying rows),
478+
// rows from earlier batches will contain corrupted data from buffer reuse.
479+
test("rows from earlier batches are not corrupted by subsequent convert() calls") {
480+
import org.apache.spark.sql.catalyst.InternalRow
481+
import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
482+
import org.apache.spark.unsafe.types.UTF8String
483+
484+
import scala.collection.mutable.ArrayBuffer
485+
486+
val schema = new StructType().add("id", IntegerType).add("str", StringType)
487+
488+
// Create multiple small batches using CometArrowConverters
489+
val numBatches = 10
490+
val rowsPerBatch = 5
491+
val totalRows = numBatches * rowsPerBatch
492+
493+
val rows = (0 until totalRows).map { i =>
494+
InternalRow(i, UTF8String.fromString(s"value_$i"))
495+
}
496+
497+
// Create batches using rowToArrowBatchIter which handles shading internally
498+
val batchIter = CometArrowConverters
499+
.rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", null)
500+
501+
val converter = new NativeColumnarToRowConverter(schema, rowsPerBatch)
502+
try {
503+
// Collect all rows from all batches into a single array
504+
// The converter returns rows that should be independent copies
505+
val allRows = new ArrayBuffer[InternalRow]()
506+
var batchCount = 0
507+
508+
while (batchIter.hasNext) {
509+
val batch = batchIter.next()
510+
batchCount += 1
511+
// Convert this batch and collect all rows
512+
val rowIter = converter.convert(batch)
513+
while (rowIter.hasNext) {
514+
allRows += rowIter.next()
515+
}
516+
batch.close()
517+
}
518+
519+
assert(batchCount == numBatches, s"Expected $numBatches batches, got $batchCount")
520+
assert(allRows.length == totalRows, s"Expected $totalRows rows, got ${allRows.length}")
521+
522+
// Verify that reading through held references produces all expected
523+
// distinct values. If rows weren't copied, all entries would point
524+
// to the same reused UnsafeRow object with stale data.
525+
val distinctIds = allRows.map(_.getInt(0)).toSet
526+
assert(
527+
distinctIds.size == totalRows,
528+
s"UnsafeRow reuse bug: expected $totalRows distinct row IDs but got " +
529+
s"${distinctIds.size} (values: ${distinctIds.toSeq.sorted.mkString(", ")}). " +
530+
"This means rows were not copied and all references point to the same " +
531+
"reused UnsafeRow object.")
532+
} finally {
533+
converter.close()
534+
}
535+
}
536+
473537
/**
474538
* Helper to create a parquet table from a DataFrame and run a function with it.
475539
*/

0 commit comments

Comments
 (0)