Skip to content

Commit e0c6cd5

Browse files
committed
perf: optimize broadcast hash join with CollectLeft mode and decompression caching
Use PartitionMode::CollectLeft instead of PartitionMode::Partitioned for broadcast hash joins so that DataFusion can optimize hash table construction for the broadcast side. Also cache the decompressed broadcast data per executor JVM, keyed by broadcast ID, so that tasks on the same executor do not each repeat the decompression (LZ4 under Spark's default `spark.io.compression.codec`).
1 parent 3d63168 commit e0c6cd5

5 files changed

Lines changed: 186 additions & 6 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,14 +1663,20 @@ impl PhysicalPlanner {
16631663
let left = Arc::clone(&join_params.left.native_plan);
16641664
let right = Arc::clone(&join_params.right.native_plan);
16651665

1666+
let partition_mode = if join.is_broadcast {
1667+
PartitionMode::CollectLeft
1668+
} else {
1669+
PartitionMode::Partitioned
1670+
};
1671+
16661672
let hash_join = Arc::new(HashJoinExec::try_new(
16671673
left,
16681674
right,
16691675
join_params.join_on,
16701676
join_params.join_filter,
16711677
&join_params.join_type,
16721678
None,
1673-
PartitionMode::Partitioned,
1679+
partition_mode,
16741680
// null doesn't equal to null in Spark join key. If the join key is
16751681
// `EqualNullSafe`, Spark will rewrite it during planning.
16761682
NullEquality::NullEqualsNothing,
@@ -1688,7 +1694,7 @@ impl PhysicalPlanner {
16881694
))
16891695
} else {
16901696
let swapped_hash_join =
1691-
hash_join.as_ref().swap_inputs(PartitionMode::Partitioned)?;
1697+
hash_join.as_ref().swap_inputs(partition_mode)?;
16921698

16931699
let mut additional_native_plans = vec![];
16941700
if swapped_hash_join.as_any().is::<ProjectionExec>() {
@@ -3905,6 +3911,7 @@ mod tests {
39053911
join_type: 0,
39063912
condition: None,
39073913
build_side: 0,
3914+
is_broadcast: false,
39083915
})),
39093916
};
39103917

native/proto/src/proto/operator.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ message HashJoin {
334334
JoinType join_type = 3;
335335
optional spark.spark_expression.Expr condition = 4;
336336
BuildSide build_side = 5;
337+
bool is_broadcast = 6;
337338
}
338339

339340
message SortMergeJoin {

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,23 @@
1919

2020
package org.apache.spark.sql.comet
2121

22+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream}
23+
import java.nio.channels.Channels
2224
import java.util.UUID
23-
import java.util.concurrent.{Future, TimeoutException, TimeUnit}
25+
import java.util.concurrent.{ConcurrentHashMap, Future, TimeoutException, TimeUnit}
2426

2527
import scala.concurrent.{ExecutionContext, Promise}
2628
import scala.concurrent.duration.NANOSECONDS
2729
import scala.util.control.NonFatal
2830

29-
import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
31+
import org.apache.spark.{broadcast, Partition, SparkContext, SparkEnv, SparkException, TaskContext}
32+
import org.apache.spark.io.CompressionCodec
3033
import org.apache.spark.rdd.RDD
3134
import org.apache.spark.sql.catalyst.InternalRow
3235
import org.apache.spark.sql.catalyst.expressions.Attribute
3336
import org.apache.spark.sql.catalyst.plans.logical.Statistics
3437
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
38+
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
3539
import org.apache.spark.sql.comet.util.Utils
3640
import org.apache.spark.sql.errors.QueryExecutionErrors
3741
import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
@@ -311,8 +315,46 @@ class CometBatchRDD(
311315

312316
/**
 * Produces the columnar batches of the broadcast relation for the given partition.
 *
 * The broadcast buffers are decompressed through `CometBatchRDD.decompressedCache`, keyed by the
 * broadcast ID, so the decompression function runs only when no entry exists for that ID yet.
 * Every call then wraps the cached byte arrays in fresh `ArrowReaderIterator`s.
 *
 * This method never removes entries from the cache; eviction is left to
 * `CometBatchRDD.invalidateCache`.
 *
 * @param split
 *   the partition to compute; cast to `CometBatchPartition`
 * @param context
 *   the task context (not used by this implementation)
 * @return
 *   an iterator over the batches read from each decompressed byte array, in array order
 */
override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
  val partition = split.asInstanceOf[CometBatchPartition]
  val broadcastId = partition.value.id
  val decompressedBytes = CometBatchRDD.decompressedCache.computeIfAbsent(
    broadcastId,
    _ => {
      // The codec is resolved from the Spark configuration of the current environment.
      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
      partition.value.value.map { chunkedBuffer =>
        val ins = codec.compressedInputStream(chunkedBuffer.toInputStream())
        // Close the stream even when a read fails, so a corrupt buffer does not leak it.
        try {
          val baos = new ByteArrayOutputStream()
          val buf = new Array[Byte](8192)
          var n = ins.read(buf)
          while (n != -1) {
            baos.write(buf, 0, n)
            n = ins.read(buf)
          }
          baos.toByteArray
        } finally {
          ins.close()
        }
      }
    })
  // Each task gets its own readers over the shared (read-only) decompressed bytes.
  decompressedBytes.iterator.flatMap { bytes =>
    new ArrowReaderIterator(
      Channels.newChannel(new ByteArrayInputStream(bytes)),
      this.getClass.getSimpleName)
  }
}
343+
}
344+
345+
object CometBatchRDD {

  /**
   * JVM-wide cache of decompressed broadcast data, keyed by broadcast ID. Each value holds the
   * decompressed Arrow IPC byte arrays of one broadcast relation, so that tasks running in the
   * same JVM can share a single decompression result instead of each repeating it.
   *
   * Entries stay in the map until [[invalidateCache]] is called for their broadcast ID.
   * NOTE(review): no call site of `invalidateCache` is visible in this change -- confirm that
   * entries are released once a broadcast is no longer needed.
   */
  private[comet] val decompressedCache: ConcurrentHashMap[Long, Array[Array[Byte]]] =
    new ConcurrentHashMap()

  /**
   * Drops the cached decompressed data for the given broadcast, if any is present.
   *
   * @param broadcastId
   *   ID of the broadcast whose cache entry should be removed
   */
  def invalidateCache(broadcastId: Long): Unit =
    decompressedCache.remove(broadcastId)
}
318360

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,7 @@ trait CometHashJoin {
16741674
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
16751675
.setBuildSide(if (join.buildSide == BuildLeft) OperatorOuterClass.BuildSide.BuildLeft
16761676
else OperatorOuterClass.BuildSide.BuildRight)
1677+
.setIsBroadcast(join.isInstanceOf[BroadcastHashJoinExec])
16771678
condition.foreach(joinBuilder.setCondition)
16781679
Some(builder.setHashJoin(joinBuilder).build())
16791680
} else {
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.benchmark
21+
22+
import org.apache.spark.SparkConf
23+
import org.apache.spark.benchmark.Benchmark
24+
import org.apache.spark.sql.SparkSession
25+
import org.apache.spark.sql.internal.SQLConf
26+
27+
import org.apache.comet.{CometConf, CometSparkSessionExtensions}
28+
29+
/**
 * Benchmark that compares Spark and Comet on broadcast hash joins between a large "streamed"
 * Parquet table and a small "broadcast" Parquet table. To run this benchmark:
 * `SPARK_GENERATE_BENCHMARK_FILES=1 make
 * benchmark-org.apache.spark.sql.benchmark.CometBroadcastHashJoinBenchmark` Results will be
 * written to "spark/benchmarks/CometBroadcastHashJoinBenchmark-**results.txt".
 */
object CometBroadcastHashJoinBenchmark extends CometBenchmarkBase {

  /**
   * Builds the session used by every case: local mode with 5 threads, the Comet extensions
   * installed, and Comet itself switched off by default (individual cases turn it on).
   */
  override def getSparkSession: SparkSession = {
    val sparkConf = new SparkConf()
      .setAppName("CometBroadcastHashJoinBenchmark")
      .set("spark.master", "local[5]")
      .setIfMissing("spark.driver.memory", "3g")
      .setIfMissing("spark.executor.memory", "3g")
      .set("spark.executor.memoryOverhead", "10g")

    val session = SparkSession.builder
      .config(sparkConf)
      .withExtensions(new CometSparkSessionExtensions)
      .getOrCreate()

    // Session-level defaults, applied in this order.
    val sessionDefaults: Seq[(String, String)] = Seq(
      SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true",
      CometConf.COMET_ENABLED.key -> "false",
      CometConf.COMET_EXEC_ENABLED.key -> "false",
      "parquet.enable.dictionary" -> "false")
    sessionDefaults.foreach { case (key, value) => session.conf.set(key, value) }

    session
  }

  /**
   * Runs one benchmark (a "Spark" case and a "Comet (Scan + Exec)" case) for a single join type.
   *
   * @param streamedRows
   *   number of rows in the large table; also used as the benchmark's values-per-iteration
   * @param broadcastRows
   *   number of rows in the small table, and the modulus used to derive the streamed join key
   * @param joinType
   *   SQL join keyword(s) placed before `JOIN` in the query, e.g. "INNER"
   */
  def broadcastHashJoinBenchmark(
      streamedRows: Int,
      broadcastRows: Int,
      joinType: String): Unit = {
    val title =
      s"Broadcast Hash Join ($joinType, stream=$streamedRows, broadcast=$broadcastRows)"
    val benchmark = new Benchmark(title, streamedRows, output = output)

    // NOTE(review): the hint targets the right-hand table; for a RIGHT JOIN Spark may not be
    // able to build on that side -- confirm which side is actually broadcast in that case.
    val query =
      s"SELECT /*+ BROADCAST(broadcast) */ s.value, b.payload " +
        s"FROM streamed s $joinType JOIN broadcast b ON s.key = b.key"

    // Both cases raise the auto-broadcast threshold to 256 MiB.
    val broadcastThreshold = s"${256 * 1024 * 1024}"

    // Registers a case that executes the join query under the given SQL settings.
    def addJoinCase(caseName: String, settings: (String, String)*): Unit =
      benchmark.addCase(caseName) { _ =>
        withSQLConf(settings: _*) {
          spark.sql(query).noop()
        }
      }

    withTempPath { dir =>
      import spark.implicits._

      val basePath = dir.getCanonicalPath

      // Large table: keys cycle through [0, broadcastRows).
      val streamedDir = basePath + "/streamed"
      val streamedData = spark
        .range(streamedRows)
        .select(($"id" % broadcastRows).as("key"), $"id".as("value"))
      streamedData.write.mode("overwrite").parquet(streamedDir)

      // Small table: one row per key.
      val broadcastDir = basePath + "/broadcast"
      val broadcastData = spark
        .range(broadcastRows)
        .select($"id".as("key"), ($"id" * 10).as("payload"))
      broadcastData.write.mode("overwrite").parquet(broadcastDir)

      spark.read.parquet(streamedDir).createOrReplaceTempView("streamed")
      spark.read.parquet(broadcastDir).createOrReplaceTempView("broadcast")

      withTempTable("streamed", "broadcast") {
        addJoinCase(
          "Spark",
          CometConf.COMET_ENABLED.key -> "false",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastThreshold)

        addJoinCase(
          "Comet (Scan + Exec)",
          CometConf.COMET_ENABLED.key -> "true",
          CometConf.COMET_EXEC_ENABLED.key -> "true",
          SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> broadcastThreshold)

        benchmark.run()
      }
    }
  }

  /**
   * Entry point: INNER, LEFT and RIGHT joins against a 1000-row broadcast table, followed by an
   * INNER join against a 10000-row broadcast table.
   */
  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
    val streamedRows = 2 * 1024 * 1024
    val smallBroadcastRows = 1000
    val largeBroadcastRows = 10000

    Seq("INNER", "LEFT", "RIGHT").foreach { joinType =>
      broadcastHashJoinBenchmark(streamedRows, smallBroadcastRows, joinType)
    }

    // Same join with a larger broadcast table.
    broadcastHashJoinBenchmark(streamedRows, largeBroadcastRows, "INNER")
  }
}

0 commit comments

Comments
 (0)