Skip to content

Commit 76ea2dd

Browse files
authored
perf: Coalesce broadcast exchange batches before broadcasting (#3703)
1 parent 1c8c873 commit 76ea2dd

4 files changed

Lines changed: 205 additions & 19 deletions

File tree

common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,15 @@ import java.nio.channels.Channels
2626
import scala.jdk.CollectionConverters._
2727

2828
import org.apache.arrow.c.CDataDictionaryProvider
29-
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
29+
import org.apache.arrow.vector._
3030
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
3131
import org.apache.arrow.vector.dictionary.DictionaryProvider
32-
import org.apache.arrow.vector.ipc.ArrowStreamWriter
32+
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
3333
import org.apache.arrow.vector.types._
3434
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
35+
import org.apache.arrow.vector.util.VectorSchemaRootAppender
3536
import org.apache.spark.{SparkEnv, SparkException}
37+
import org.apache.spark.internal.Logging
3638
import org.apache.spark.io.CompressionCodec
3739
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
3840
import org.apache.spark.sql.types._
@@ -43,7 +45,7 @@ import org.apache.comet.Constants.COMET_CONF_DIR_ENV
4345
import org.apache.comet.shims.CometTypeShim
4446
import org.apache.comet.vector.CometVector
4547

46-
object Utils extends CometTypeShim {
48+
object Utils extends CometTypeShim with Logging {
4749
def getConfPath(confFileName: String): String = {
4850
sys.env
4951
.get(COMET_CONF_DIR_ENV)
@@ -232,6 +234,7 @@ object Utils extends CometTypeShim {
232234

233235
/**
234236
* Decodes the byte arrays back to ColumnarBatches and puts them into a buffer.
237+
*
235238
* @param bytes
236239
* the serialized batches
237240
* @param source
@@ -252,6 +255,108 @@ object Utils extends CometTypeShim {
252255
new ArrowReaderIterator(Channels.newChannel(ins), source)
253256
}
254257

258+
/**
 * Coalesces many small Arrow IPC batches into a single batch for broadcasting.
 *
 * Why this is necessary: The broadcast exchange collects shuffle output by calling
 * getByteArrayRdd, which serializes each ColumnarBatch independently into its own
 * ChunkedByteBuffer. The shuffle reader (CometBlockStoreShuffleReader) produces one
 * ColumnarBatch per shuffle block, and there is one block per writer task per output partition.
 * So with W writer tasks and P output partitions, the broadcast collects up to W * P tiny
 * batches. For example, with 400 writer tasks and 500 partitions, 1M rows would arrive as ~200K
 * batches of ~5 rows each.
 *
 * Without coalescing, every consumer task in the broadcast join would independently deserialize
 * all of these tiny Arrow IPC streams, paying per-stream overhead (schema parsing, buffer
 * allocation) for each one. With coalescing, we decode and append all batches into one
 * VectorSchemaRoot on the driver, then re-serialize once. Each consumer task then deserializes
 * a single Arrow IPC stream.
 *
 * Zero-sized buffers are dropped before decoding. Every remaining buffer is decompressed with
 * the codec created from the current SparkEnv configuration and read as an Arrow IPC stream.
 *
 * @param input
 *   the compressed, serialized Arrow IPC streams to merge
 * @return
 *   a triple of (buffers to broadcast, number of batches merged, number of rows merged):
 *   - a single re-serialized buffer plus the merged batch/row counts on success;
 *   - an empty array and zero counts when there is no non-empty buffer or no batch was decoded;
 *   - the original non-empty buffers and zero counts when a dictionary-encoded column is found.
 */
def coalesceBroadcastBatches(
    input: Iterator[ChunkedByteBuffer]): (Array[ChunkedByteBuffer], Long, Long) = {
  val buffers = input.filterNot(_.size == 0).toArray
  if (buffers.isEmpty) {
    (Array.empty[ChunkedByteBuffer], 0L, 0L)
  } else {
    val allocator = org.apache.comet.CometArrowAllocator
      .newChildAllocator("broadcast-coalesce", 0, Long.MaxValue)
    try {
      // Accumulates every decoded batch; stays None until the first batch is loaded.
      // NOTE(review): all rows end up in one VectorSchemaRoot, whose row count is an Int
      // (see the assert below) - presumably the caller bounds broadcast size; confirm.
      var targetRoot: Option[VectorSchemaRoot] = None
      var totalRows = 0L
      var batchCount = 0
      var dictionaryFound = false

      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
      try {
        // A flag-terminated while loop is used instead of `return` inside a for-closure:
        // that would be a non-local return (exception based). It also leaves exactly one
        // place - the finally blocks - responsible for closing the reader and target root.
        val pending = buffers.iterator
        while (!dictionaryFound && pending.hasNext) {
          val compressedInputStream =
            new DataInputStream(codec.compressedInputStream(pending.next().toInputStream()))
          val reader =
            new ArrowStreamReader(Channels.newChannel(compressedInputStream), allocator)
          try {
            // Comet decodes dictionaries during execution, so this shouldn't happen.
            // If it does, fall back to the original uncoalesced buffers because each
            // partition can have a different dictionary, and appending index vectors
            // would silently mix indices from incompatible dictionaries.
            if (!reader.getDictionaryVectors.isEmpty) {
              logWarning(
                "Unexpected dictionary-encoded column during BroadcastExchange coalescing; " +
                  "skipping coalesce")
              dictionaryFound = true
            } else {
              while (reader.loadNextBatch()) {
                val sourceRoot = reader.getVectorSchemaRoot
                val target = targetRoot.getOrElse {
                  val created = VectorSchemaRoot.create(sourceRoot.getSchema, allocator)
                  // Register before allocating so the finally block below releases the
                  // root even if allocateNew() throws.
                  targetRoot = Some(created)
                  created.allocateNew()
                  created
                }
                VectorSchemaRootAppender.append(target, sourceRoot)
                totalRows += sourceRoot.getRowCount
                batchCount += 1
              }
            }
          } finally {
            reader.close()
          }
        }

        targetRoot match {
          case _ if dictionaryFound =>
            // Fallback: hand back the untouched buffers and report nothing as coalesced.
            (buffers, 0L, 0L)

          case None =>
            // Every stream was valid but contained no record batch.
            (Array.empty[ChunkedByteBuffer], 0L, 0L)

          case Some(root) =>
            assert(
              root.getRowCount.toLong == totalRows,
              s"Row count mismatch after coalesce: ${root.getRowCount} != $totalRows")

            logInfo(s"Coalesced $batchCount broadcast batches into 1 ($totalRows rows)")

            val outputStream = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
            val compressedOutputStream =
              new DataOutputStream(codec.compressedOutputStream(outputStream))
            val writer =
              new ArrowStreamWriter(root, null, Channels.newChannel(compressedOutputStream))
            try {
              writer.start()
              writer.writeBatch()
            } finally {
              // Must run before toChunkedByteBuffer so the stream is complete.
              writer.close()
            }

            (Array(outputStream.toChunkedByteBuffer), batchCount.toLong, totalRows)
        }
      } finally {
        targetRoot.foreach(_.close())
      }
    } finally {
      allocator.close()
    }
  }
}
359+
255360
def getBatchFieldVectors(
256361
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
257362
var provider: Option[DictionaryProvider] = None

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ case class CometBroadcastExchangeExec(
7777
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
7878
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"),
7979
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build"),
80-
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"))
80+
"broadcastTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to broadcast"),
81+
"numCoalescedBatches" -> SQLMetrics.createMetric(
82+
sparkContext,
83+
"number of coalesced batches for broadcast"),
84+
"numCoalescedRows" -> SQLMetrics.createMetric(
85+
sparkContext,
86+
"number of coalesced rows for broadcast"))
8187

8288
override def doCanonicalize(): SparkPlan = {
8389
CometBroadcastExchangeExec(null, null, mode, child.canonicalized)
@@ -155,7 +161,14 @@ case class CometBroadcastExchangeExec(
155161
val beforeBuild = System.nanoTime()
156162
longMetric("collectTime") += NANOSECONDS.toMillis(beforeBuild - beforeCollect)
157163

158-
val batches = input.toArray
164+
// Coalesce the many small per-shuffle-block buffers into a single buffer.
165+
// Without this, each consumer task deserializes one Arrow IPC stream per
166+
// shuffle block (one per writer task per partition), which is very expensive
167+
// when there are hundreds of writer tasks and partitions. See the scaladoc
168+
// on coalesceBroadcastBatches for details.
169+
val (batches, coalescedBatches, coalescedRows) = Utils.coalesceBroadcastBatches(input)
170+
longMetric("numCoalescedBatches") += coalescedBatches
171+
longMetric("numCoalescedRows") += coalescedRows
159172

160173
val dataSize = batches.map(_.size).sum
161174

spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, He
3535
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate}
3636
import org.apache.spark.sql.comet._
3737
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
38-
import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SparkPlan, SQLExecution, UnionExec}
38+
import org.apache.spark.sql.execution._
3939
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
4040
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
4141
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -474,6 +474,10 @@ class CometExecSuite extends CometTestBase {
474474
val expected = (0 until numParts).flatMap(_ => (0 until 5).map(i => i + 1)).sorted
475475

476476
assert(rowContents === expected)
477+
478+
val metrics = nativeBroadcast.metrics
479+
assert(metrics("numCoalescedBatches").value == 5L)
480+
assert(metrics("numCoalescedRows").value == 5L)
477481
}
478482
}
479483
}
@@ -493,6 +497,10 @@ class CometExecSuite extends CometTestBase {
493497
}.get.asInstanceOf[CometBroadcastExchangeExec]
494498
val rows = nativeBroadcast.executeCollect()
495499
assert(rows.isEmpty)
500+
501+
val metrics = nativeBroadcast.metrics
502+
assert(metrics("numCoalescedBatches").value == 0L)
503+
assert(metrics("numCoalescedRows").value == 0L)
496504
}
497505
}
498506
}
@@ -712,7 +720,7 @@ class CometExecSuite extends CometTestBase {
712720
assert(metrics.contains("build_time"))
713721
assert(metrics("build_time").value > 1L)
714722
assert(metrics.contains("build_input_batches"))
715-
assert(metrics("build_input_batches").value == 25L)
723+
assert(metrics("build_input_batches").value == 5L)
716724
assert(metrics.contains("build_mem_used"))
717725
assert(metrics("build_mem_used").value > 1L)
718726
assert(metrics.contains("build_input_rows"))

spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.sql.internal.SQLConf
3131
import org.apache.comet.CometConf
3232

3333
class CometJoinSuite extends CometTestBase {
34+
3435
import testImplicits._
3536

3637
override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
@@ -359,28 +360,87 @@ class CometJoinSuite extends CometTestBase {
359360
checkSparkAnswer(left.join(right, ($"left.N" === $"right.N") && ($"right.N" =!= 3), "full"))
360361

361362
checkSparkAnswer(sql("""
362-
|SELECT l.a, count(*)
363-
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
364-
|GROUP BY l.a
363+
|SELECT l.a, count(*)
364+
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
365+
|GROUP BY l.a
365366
""".stripMargin))
366367

367368
checkSparkAnswer(sql("""
368-
|SELECT r.N, count(*)
369-
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
370-
|GROUP BY r.N
369+
|SELECT r.N, count(*)
370+
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
371+
|GROUP BY r.N
371372
""".stripMargin))
372373

373374
checkSparkAnswer(sql("""
374-
|SELECT l.N, count(*)
375-
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
376-
|GROUP BY l.N
375+
|SELECT l.N, count(*)
376+
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
377+
|GROUP BY l.N
377378
""".stripMargin))
378379

379380
checkSparkAnswer(sql("""
380-
|SELECT r.a, count(*)
381-
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
382-
|GROUP BY r.a
381+
|SELECT r.a, count(*)
382+
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
383+
|GROUP BY r.a
383384
""".stripMargin))
384385
}
385386
}
387+
388+
test("Broadcast hash join build-side batch coalescing") {
  // A large shuffle partition count makes the broadcast source arrive as many small
  // batches; the assertions below check that they are merged before reaching the
  // build side of the join (1 batch per task).
  val numPartitions = 512
  val rowCount = 10000

  val confs = Seq(
    CometConf.COMET_BATCH_SIZE.key -> "100",
    SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
    "spark.sql.join.forceApplyShuffledHashJoin" -> "true",
    SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
    SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
    SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)

  val leftRows = (0 until rowCount).map { i => (i, i % 5) }
  val rightRows = (0 until rowCount).map { i => (i % 10, i + 2) }

  withSQLConf(confs: _*) {
    withParquetTable(leftRows, "tbl_a") {
      withParquetTable(rightRows, "tbl_b") {
        // The REPARTITION hint forces a shuffle on tbl_a ahead of the broadcast, so the
        // broadcast source has numPartitions partitions rather than one per parquet file.
        val query =
          s"""SELECT /*+ BROADCAST(a) */ *
             |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM tbl_a) a
             |JOIN tbl_b ON a._2 = tbl_b._1""".stripMargin

        val cometPlan = checkSparkAnswerAndOperator(
          sql(query),
          Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))._2

        val hashJoins = collect(cometPlan) { case node: CometBroadcastHashJoinExec => node }
        assert(hashJoins.nonEmpty, "Expected CometBroadcastHashJoinExec in plan")

        val buildBatches = hashJoins.head.metrics("build_input_batches").value

        // Uncoalesced, every task would see ~numPartitions build batches, i.e. roughly
        // numPartitions * numPartitions in total. Coalesced, each task sees a single
        // batch, so the total is about numPartitions.
        assert(
          buildBatches <= numPartitions,
          s"Expected at most $numPartitions build batches (1 per task), got $buildBatches. " +
            "Broadcast batch coalescing may not be working.")

        val exchanges = collect(cometPlan) { case node: CometBroadcastExchangeExec => node }
        assert(exchanges.nonEmpty, "Expected CometBroadcastExchangeExec in plan")

        val exchangeMetrics = exchanges.head.metrics
        val mergedBatches = exchangeMetrics("numCoalescedBatches").value
        val mergedRows = exchangeMetrics("numCoalescedRows").value

        assert(
          mergedBatches >= numPartitions,
          s"Expected at least $numPartitions coalesced batches, got $mergedBatches")
        assert(mergedRows == rowCount, s"Expected $rowCount coalesced rows, got $mergedRows")
      }
    }
  }
}
386446
}

0 commit comments

Comments
 (0)