@@ -32,7 +32,9 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider
3232import org .apache .arrow .vector .ipc .ArrowStreamWriter
3333import org .apache .arrow .vector .types ._
3434import org .apache .arrow .vector .types .pojo .{ArrowType , Field , FieldType , Schema }
35+ import org .apache .arrow .vector .util .VectorSchemaRootAppender
3536import org .apache .spark .{SparkEnv , SparkException }
37+ import org .apache .spark .internal .Logging
3638import org .apache .spark .io .CompressionCodec
3739import org .apache .spark .sql .comet .execution .arrow .ArrowReaderIterator
3840import org .apache .spark .sql .types ._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
4345import org .apache .comet .shims .CometTypeShim
4446import org .apache .comet .vector .CometVector
4547
46- object Utils extends CometTypeShim {
48+ object Utils extends CometTypeShim with Logging {
4749 def getConfPath (confFileName : String ): String = {
4850 sys.env
4951 .get(COMET_CONF_DIR_ENV )
@@ -252,6 +254,75 @@ object Utils extends CometTypeShim {
252254 new ArrowReaderIterator (Channels .newChannel(ins), source)
253255 }
254256
/**
 * Coalesces many small ChunkedByteBuffers (one per source partition) into a single
 * ChunkedByteBuffer containing one Arrow IPC stream with one record batch. This avoids each
 * consumer partition having to deserialize N separate streams.
 *
 * If any decoded batch reports a dictionary provider, the batches are not merged (merging
 * dictionaries across batches is not supported); every batch is instead re-serialized on its
 * own through `serializeBatches`.
 *
 * All decoded batches are closed before this method returns, on both the success and the
 * failure path.
 *
 * @param input
 *   serialized buffers to decode; the iterator is fully consumed
 * @return
 *   an empty array when `input` yields no batches; a single-element array holding the
 *   coalesced stream; or, in the dictionary fallback case, the buffers produced by
 *   `serializeBatches` for each batch in order
 */
def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = {
  // Writes `root` as a single-batch, compressed Arrow IPC stream held in memory.
  def serializeRoot(root: VectorSchemaRoot): ChunkedByteBuffer = {
    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
    val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
    val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
    try {
      writer.start()
      writer.writeBatch()
    } finally {
      // Closing the writer also closes the wrapped channel, so the compression stream is
      // released even when writing fails part-way through.
      writer.close()
    }
    cbbos.toChunkedByteBuffer
  }

  // Appends every source root into one freshly allocated root and serializes the result.
  def mergeAndSerialize(
      batches: Array[ColumnarBatch],
      sourceRoots: Array[VectorSchemaRoot]): ChunkedByteBuffer = {
    val allocator = org.apache.comet.CometArrowAllocator
      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
    try {
      val targetRoot = VectorSchemaRoot.create(sourceRoots.head.getSchema, allocator)
      try {
        VectorSchemaRootAppender.append(targetRoot, sourceRoots: _*)

        // Sum as Long so the sanity check itself cannot overflow.
        val expectedRows = batches.map(_.numRows().toLong).sum
        assert(
          targetRoot.getRowCount.toLong == expectedRows,
          s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $expectedRows")

        logInfo(
          s"Coalesced ${batches.length} broadcast batches into 1 " +
            s"($expectedRows rows)")

        serializeRoot(targetRoot)
      } finally {
        targetRoot.close()
      }
    } finally {
      allocator.close()
    }
  }

  val decoded = input.flatMap(decodeBatches(_, "broadcast-coalesce")).toArray
  if (decoded.isEmpty) {
    Array.empty
  } else {
    try {
      val extracted = decoded.map(batch => getBatchFieldVectors(batch))
      val hasDictionary = extracted.exists { case (_, providerOpt) => providerOpt.isDefined }

      if (hasDictionary) {
        // Fall back to per-batch serialization if any batch has dictionary-encoded vectors,
        // since merging dictionaries across batches is not supported.
        logInfo(
          s"Broadcast coalesce falling back to per-batch serialization due to " +
            s"dictionary-encoded vectors (${decoded.length} batches)")
        decoded.flatMap { batch =>
          serializeBatches(Iterator(batch)).map(_._2)
        }
      } else {
        // These roots only wrap the batches' vectors; the vectors themselves are released
        // when the decoded batches are closed below.
        val sourceRoots = extracted.map { case (fieldVectors, _) =>
          new VectorSchemaRoot(fieldVectors.asJava)
        }
        Array(mergeAndSerialize(decoded, sourceRoots))
      }
    } finally {
      decoded.foreach(_.close())
    }
  }
}
325+
255326 def getBatchFieldVectors (
256327 batch : ColumnarBatch ): (Seq [FieldVector ], Option [DictionaryProvider ]) = {
257328 var provider : Option [DictionaryProvider ] = None
0 commit comments