@@ -26,13 +26,15 @@ import java.nio.channels.Channels
2626import scala .jdk .CollectionConverters ._
2727
2828import org .apache .arrow .c .CDataDictionaryProvider
29- import org .apache .arrow .vector .{ BigIntVector , BitVector , DateDayVector , DecimalVector , FieldVector , FixedSizeBinaryVector , Float4Vector , Float8Vector , IntVector , NullVector , SmallIntVector , TimeStampMicroTZVector , TimeStampMicroVector , TinyIntVector , ValueVector , VarBinaryVector , VarCharVector , VectorSchemaRoot }
29+ import org .apache .arrow .vector ._
3030import org .apache .arrow .vector .complex .{ListVector , MapVector , StructVector }
3131import org .apache .arrow .vector .dictionary .DictionaryProvider
32- import org .apache .arrow .vector .ipc .ArrowStreamWriter
32+ import org .apache .arrow .vector .ipc .{ ArrowStreamReader , ArrowStreamWriter }
3333import org .apache .arrow .vector .types ._
3434import org .apache .arrow .vector .types .pojo .{ArrowType , Field , FieldType , Schema }
35+ import org .apache .arrow .vector .util .VectorSchemaRootAppender
3536import org .apache .spark .{SparkEnv , SparkException }
37+ import org .apache .spark .internal .Logging
3638import org .apache .spark .io .CompressionCodec
3739import org .apache .spark .sql .comet .execution .arrow .ArrowReaderIterator
3840import org .apache .spark .sql .types ._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
4345import org .apache .comet .shims .CometTypeShim
4446import org .apache .comet .vector .CometVector
4547
46- object Utils extends CometTypeShim {
48+ object Utils extends CometTypeShim with Logging {
4749 def getConfPath (confFileName : String ): String = {
4850 sys.env
4951 .get(COMET_CONF_DIR_ENV )
@@ -232,6 +234,7 @@ object Utils extends CometTypeShim {
232234
233235 /**
234236 * Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
237+ *
235238 * @param bytes
236239 * the serialized batches
237240 * @param source
@@ -252,6 +255,108 @@ object Utils extends CometTypeShim {
252255 new ArrowReaderIterator (Channels .newChannel(ins), source)
253256 }
254257
/**
 * Coalesces many small Arrow IPC batches into a single batch for broadcasting.
 *
 * Why this is necessary: The broadcast exchange collects shuffle output by calling
 * getByteArrayRdd, which serializes each ColumnarBatch independently into its own
 * ChunkedByteBuffer. The shuffle reader (CometBlockStoreShuffleReader) produces one
 * ColumnarBatch per shuffle block, and there is one block per writer task per output partition.
 * So with W writer tasks and P output partitions, the broadcast collects up to W * P tiny
 * batches. For example, with 400 writer tasks and 500 partitions, 1M rows would arrive as ~200K
 * batches of ~5 rows each.
 *
 * Without coalescing, every consumer task in the broadcast join would independently deserialize
 * all of these tiny Arrow IPC streams, paying per-stream overhead (schema parsing, buffer
 * allocation) for each one. With coalescing, we decode and append all batches into one
 * VectorSchemaRoot on the driver, then re-serialize once. Each consumer task then deserializes
 * a single Arrow IPC stream.
 *
 * @param input
 *   the serialized (compressed Arrow IPC) batches; zero-sized buffers are ignored
 * @return
 *   a triple of (buffers, number of batches coalesced, total row count):
 *   - normally a single-element array holding the re-serialized coalesced batch, together with
 *     the number of source batches that were appended and the total number of rows;
 *   - `(Array.empty, 0, 0)` when there are no non-empty buffers or no batch could be loaded;
 *   - `(the non-empty input buffers, 0, 0)` when a dictionary-encoded column is encountered
 *     and coalescing is skipped.
 */
def coalesceBroadcastBatches(
    input: Iterator[ChunkedByteBuffer]): (Array[ChunkedByteBuffer], Long, Long) = {
  val buffers = input.filterNot(_.size == 0).toArray
  if (buffers.isEmpty) {
    (Array.empty[ChunkedByteBuffer], 0L, 0L)
  } else {
    val allocator = org.apache.comet.CometArrowAllocator
      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
    try {
      var targetRoot: VectorSchemaRoot = null
      var totalRows = 0L
      var batchCount = 0
      // Set when a dictionary-encoded stream is seen; stops the scan and triggers the fallback.
      var dictionaryEncoded = false

      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
      try {
        // A plain while loop (rather than `for`/`foreach`) keeps the early exit a simple flag
        // check instead of a non-local `return` thrown out of a closure.
        var i = 0
        while (!dictionaryEncoded && i < buffers.length) {
          val compressedInputStream =
            new DataInputStream(codec.compressedInputStream(buffers(i).toInputStream()))
          val reader =
            new ArrowStreamReader(Channels.newChannel(compressedInputStream), allocator)
          try {
            // Comet decodes dictionaries during execution, so this shouldn't happen.
            // If it does, fall back to the original uncoalesced buffers because each
            // partition can have a different dictionary, and appending index vectors
            // would silently mix indices from incompatible dictionaries.
            if (!reader.getDictionaryVectors.isEmpty) {
              logWarning(
                "Unexpected dictionary-encoded column during BroadcastExchange coalescing; " +
                  "skipping coalesce")
              dictionaryEncoded = true
            } else {
              while (reader.loadNextBatch()) {
                val sourceRoot = reader.getVectorSchemaRoot
                if (targetRoot == null) {
                  // Lazily create the accumulator from the first batch's schema.
                  targetRoot = VectorSchemaRoot.create(sourceRoot.getSchema, allocator)
                  targetRoot.allocateNew()
                }
                VectorSchemaRootAppender.append(targetRoot, sourceRoot)
                totalRows += sourceRoot.getRowCount
                batchCount += 1
              }
            }
          } finally {
            // Single point of cleanup for the reader on every path (normal, fallback, error).
            reader.close()
          }
          i += 1
        }

        if (dictionaryEncoded) {
          // Hand back the (non-empty) input buffers untouched.
          (buffers, 0L, 0L)
        } else if (targetRoot == null) {
          // Every stream was readable but none contained a batch.
          (Array.empty[ChunkedByteBuffer], 0L, 0L)
        } else {
          assert(
            targetRoot.getRowCount.toLong == totalRows,
            s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $totalRows")

          logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)")

          val outputStream = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
          val compressedOutputStream =
            new DataOutputStream(codec.compressedOutputStream(outputStream))
          val writer =
            new ArrowStreamWriter(targetRoot, null, Channels.newChannel(compressedOutputStream))
          try {
            writer.start()
            writer.writeBatch()
          } finally {
            // Closing the writer also closes the underlying channel, flushing the codec stream
            // before the bytes are materialized below.
            writer.close()
          }

          (Array(outputStream.toChunkedByteBuffer), batchCount.toLong, totalRows)
        }
      } finally {
        // Releases any partially built accumulator on all paths, including the fallback.
        if (targetRoot != null) {
          targetRoot.close()
        }
      }
    } finally {
      allocator.close()
    }
  }
}
359+
255360 def getBatchFieldVectors (
256361 batch : ColumnarBatch ): (Seq [FieldVector ], Option [DictionaryProvider ]) = {
257362 var provider : Option [DictionaryProvider ] = None
0 commit comments