Skip to content

Commit 96424a7

Browse files
committed
Coalesce broadcast exchange batches before broadcasting
CometBroadcastExchangeExec previously broadcast an Array[ChunkedByteBuffer] with one entry per source partition. Each consumer partition independently deserialized all entries, creating a separate compression codec and Arrow IPC reader per entry. For broadcasts with many source partitions, this produced large per-task overhead in the hash join build-side collection. Decode and concatenate all broadcast batches into a single ChunkedByteBuffer on the driver using VectorSchemaRootAppender before broadcasting. Falls back to per-batch serialization if dictionary-encoded vectors are present.
1 parent 9b773f3 commit 96424a7

2 files changed

Lines changed: 76 additions & 2 deletions

File tree

common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider
3232
import org.apache.arrow.vector.ipc.ArrowStreamWriter
3333
import org.apache.arrow.vector.types._
3434
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
35+
import org.apache.arrow.vector.util.VectorSchemaRootAppender
3536
import org.apache.spark.{SparkEnv, SparkException}
37+
import org.apache.spark.internal.Logging
3638
import org.apache.spark.io.CompressionCodec
3739
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
3840
import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
4345
import org.apache.comet.shims.CometTypeShim
4446
import org.apache.comet.vector.CometVector
4547

46-
object Utils extends CometTypeShim {
48+
object Utils extends CometTypeShim with Logging {
4749
def getConfPath(confFileName: String): String = {
4850
sys.env
4951
.get(COMET_CONF_DIR_ENV)
@@ -252,6 +254,75 @@ object Utils extends CometTypeShim {
252254
new ArrowReaderIterator(Channels.newChannel(ins), source)
253255
}
254256

257+
/**
 * Coalesces many small ChunkedByteBuffers (one per source partition) into as few buffers as
 * possible, so that each consumer partition does not have to deserialize N separate streams.
 *
 * Every input buffer is decoded with `decodeBatches`. The result is one of:
 *   - an empty array, when the input decodes to no batches;
 *   - a single ChunkedByteBuffer holding one compressed Arrow IPC stream with one record
 *     batch (all decoded batches appended together), when no batch reports a dictionary
 *     provider;
 *   - the buffers produced by `serializeBatches` for each batch individually, when any batch
 *     reports a dictionary provider, since merging dictionaries across batches is not
 *     supported.
 *
 * All decoded batches are closed before this method returns, including when decoding or
 * coalescing fails part-way through.
 *
 * @param input
 *   serialized batches to coalesce; the iterator is fully consumed
 * @return
 *   the coalesced buffer(s), as described above
 */
def coalesceBroadcastBatches(input: Iterator[ChunkedByteBuffer]): Array[ChunkedByteBuffer] = {
  // Accumulate into a buffer declared outside the `try` so that batches decoded before a
  // mid-stream decode failure are still released by the `finally` below.
  val decoded = scala.collection.mutable.ArrayBuffer.empty[ColumnarBatch]
  try {
    input.foreach { buffer =>
      decodeBatches(buffer, "broadcast-coalesce").foreach(decoded += _)
    }
    val batches = decoded.toArray

    if (batches.isEmpty) {
      Array.empty[ChunkedByteBuffer]
    } else {
      val batchVectors = batches.map(batch => getBatchFieldVectors(batch))

      // Fall back to per-batch serialization if any batch has dictionary-encoded vectors,
      // since merging dictionaries across batches is not supported.
      if (batchVectors.exists(_._2.isDefined)) {
        logInfo(
          s"Broadcast coalesce falling back to per-batch serialization due to " +
            s"dictionary-encoded vectors (${batches.length} batches)")
        // Array.flatMap is strict, so every batch is serialized before `finally` closes it.
        batches.flatMap { batch =>
          serializeBatches(Iterator(batch)).map(_._2)
        }
      } else {
        // These roots only wrap the vectors owned by the decoded batches; the vectors are
        // released when the batches are closed in the outer `finally`.
        val sourceRoots = batchVectors.map { case (fieldVectors, _) =>
          new VectorSchemaRoot(fieldVectors.asJava)
        }

        val allocator = org.apache.comet.CometArrowAllocator
          .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
        try {
          val targetRoot = VectorSchemaRoot.create(sourceRoots.head.getSchema, allocator)
          try {
            VectorSchemaRootAppender.append(targetRoot, sourceRoots: _*)

            val expectedRows = batches.map(_.numRows().toLong).sum
            assert(
              targetRoot.getRowCount.toLong == expectedRows,
              s"Row count mismatch after coalesce: ${targetRoot.getRowCount} != $expectedRows")

            logInfo(
              s"Coalesced ${batches.length} broadcast batches into 1 " +
                s"($expectedRows rows)")

            val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
            val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
            val out = new DataOutputStream(codec.compressedOutputStream(cbbos))
            try {
              val writer = new ArrowStreamWriter(targetRoot, null, Channels.newChannel(out))
              writer.start()
              writer.writeBatch()
              // On success, closing the writer also closes the wrapped stream chain, which
              // must happen before `toChunkedByteBuffer` is called below.
              writer.close()
            } catch {
              case scala.util.control.NonFatal(e) =>
                // The writer did not finish cleanly: release the compression stream
                // ourselves, without letting a secondary close failure mask the root cause.
                try {
                  out.close()
                } catch {
                  case scala.util.control.NonFatal(suppressed) => e.addSuppressed(suppressed)
                }
                throw e
            }

            Array(cbbos.toChunkedByteBuffer)
          } finally {
            targetRoot.close()
          }
        } finally {
          allocator.close()
        }
      }
    }
  } finally {
    decoded.foreach(_.close())
  }
}
325+
255326
def getBatchFieldVectors(
256327
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
257328
var provider: Option[DictionaryProvider] = None

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ case class CometBroadcastExchangeExec(
155155
val beforeBuild = System.nanoTime()
156156
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
157157

158-
val batches = input.toArray
158+
// Coalesce many small per-partition buffers into a single buffer so each
159+
// consumer partition only deserializes one Arrow IPC stream.
160+
// May produce multiple buffers if dictionary-encoded vectors are present.
161+
val batches = Utils.coalesceBroadcastBatches(input)
159162

160163
val dataSize = batches.map(_.size).sum
161164

0 commit comments

Comments
 (0)