Skip to content

Commit a86c599

Browse files
committed
Merge branch 'main' into db/split-lint
2 parents 366197e + 71f22f5 commit a86c599

11 files changed

Lines changed: 556 additions & 66 deletions

File tree

common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ import java.nio.channels.Channels
2626
import scala.jdk.CollectionConverters._
2727

2828
import org.apache.arrow.c.CDataDictionaryProvider
29-
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
29+
import org.apache.arrow.vector._
3030
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
3131
import org.apache.arrow.vector.dictionary.DictionaryProvider
32-
import org.apache.arrow.vector.ipc.ArrowStreamWriter
32+
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
3333
import org.apache.arrow.vector.types._
3434
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
35+
import org.apache.arrow.vector.util.VectorSchemaRootAppender
3536
import org.apache.spark.{SparkEnv, SparkException}
37+
import org.apache.spark.internal.Logging
3638
import org.apache.spark.io.CompressionCodec
3739
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
3840
import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
4345
import org.apache.comet.shims.CometTypeShim
4446
import org.apache.comet.vector.CometVector
4547

46-
object Utils extends CometTypeShim {
48+
object Utils extends CometTypeShim with Logging {
4749
def getConfPath(confFileName: String): String = {
4850
sys.env
4951
.get(COMET_CONF_DIR_ENV)
@@ -232,6 +234,7 @@ object Utils extends CometTypeShim {
232234

233235
/**
234236
* Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
237+
*
235238
* @param bytes
236239
* the serialized batches
237240
* @param source
@@ -252,6 +255,108 @@ object Utils extends CometTypeShim {
252255
new ArrowReaderIterator(Channels.newChannel(ins), source)
253256
}
254257

258+
/**
259+
* Coalesces many small Arrow IPC batches into a single batch for broadcasting.
260+
*
261+
* Why this is necessary: The broadcast exchange collects shuffle output by calling
262+
* getByteArrayRdd, which serializes each ColumnarBatch independently into its own
263+
* ChunkedByteBuffer. The shuffle reader (CometBlockStoreShuffleReader) produces one
264+
* ColumnarBatch per shuffle block, and there is one block per writer task per output partition.
265+
* So with W writer tasks and P output partitions, the broadcast collects up to W * P tiny
266+
* batches. For example, with 400 writer tasks and 500 partitions, 1M rows would arrive as ~200K
267+
* batches of ~5 rows each.
268+
*
269+
* Without coalescing, every consumer task in the broadcast join would independently deserialize
270+
* all of these tiny Arrow IPC streams, paying per-stream overhead (schema parsing, buffer
271+
* allocation) for each one. With coalescing, we decode and append all batches into one
272+
* VectorSchemaRoot on the driver, then re-serialize once. Each consumer task then deserializes
273+
* a single Arrow IPC stream.
274+
*/
275+
def coalesceBroadcastBatches(
    input: Iterator[ChunkedByteBuffer]): (Array[ChunkedByteBuffer], Long, Long) = {
  // Result tuple is (buffers, batchCount, totalRows):
  //  - no non-empty input buffers, or no batch decoded: (empty array, 0, 0)
  //  - a dictionary-encoded stream was encountered:      (original non-empty buffers, 0, 0)
  //  - otherwise: (single re-serialized buffer, number of batches appended, rows appended)
  val buffers = input.filterNot(_.size == 0).toArray
  if (buffers.isEmpty) {
    (Array.empty[ChunkedByteBuffer], 0L, 0L)
  } else {
    val allocator = org.apache.comet.CometArrowAllocator
      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
    try {
      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)

      // Accumulator root; created lazily from the schema of the first decoded batch.
      var targetRoot: VectorSchemaRoot = null
      var totalRows = 0L
      var batchCount = 0
      // Set when a stream carries dictionaries; stops the loop and triggers the fallback.
      var sawDictionary = false

      try {
        // Plain index loop instead of `for (bytes <- buffers)`: bailing out of a
        // foreach closure would require a non-local `return`, which is implemented by
        // throwing an exception through the closure.
        var idx = 0
        while (!sawDictionary && idx < buffers.length) {
          val compressedInputStream =
            new DataInputStream(codec.compressedInputStream(buffers(idx).toInputStream()))
          val reader =
            new ArrowStreamReader(Channels.newChannel(compressedInputStream), allocator)
          try {
            // Comet decodes dictionaries during execution, so this shouldn't happen.
            // If it does, fall back to the original uncoalesced buffers because each
            // partition can have a different dictionary, and appending index vectors
            // would silently mix indices from incompatible dictionaries.
            if (!reader.getDictionaryVectors.isEmpty) {
              logWarning(
                "Unexpected dictionary-encoded column during BroadcastExchange coalescing; " +
                  "skipping coalesce")
              sawDictionary = true
            } else {
              while (reader.loadNextBatch()) {
                val sourceRoot = reader.getVectorSchemaRoot
                if (targetRoot == null) {
                  targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, allocator)
                  targetRoot.allocateNew()
                }
                VectorSchemaRootAppender.append(targetRoot, sourceRoot)
                totalRows += sourceRoot.getRowCount
                batchCount += 1
              }
            }
          } finally {
            // The reader is closed exactly once, here, on every path.
            reader.close()
          }
          idx += 1
        }

        if (sawDictionary) {
          // Any partially built targetRoot is released by the enclosing finally.
          (buffers, 0L, 0L)
        } else if (targetRoot == null) {
          // Every stream was valid but contained no record batch.
          (Array.empty[ChunkedByteBuffer], 0L, 0L)
        } else {
          assert(
            targetRoot.getRowCount.toLong == totalRows,
            s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows")

          logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)")

          val outputStream = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
          val compressedOutputStream =
            new DataOutputStream(codec.compressedOutputStream(outputStream))
          val writer =
            new ArrowStreamWriter(targetRoot, null, Channels.newChannel(compressedOutputStream))
          try {
            writer.start()
            writer.writeBatch()
          } finally {
            // Closing the writer must happen before toChunkedByteBuffer is read below.
            writer.close()
          }

          (Array(outputStream.toChunkedByteBuffer), batchCount.toLong, totalRows)
        }
      } finally {
        if (targetRoot != null) {
          targetRoot.close()
        }
      }
    } finally {
      // Release the child allocator only after every root/reader created from it is closed.
      allocator.close()
    }
  }
}
359+
255360
def getBatchFieldVectors(
256361
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
257362
var provider: Option[DictionaryProvider] = None

0 commit comments

Comments
 (0)