Skip to content

Commit 63fd9ac

Browse files
andygrove and wForget authored
fix: [branch-0.10] Avoid spark plan execution cache preventing CometBatchRDD numPartitions change (#2420) (#2503)
* fix merge conflicts
* trigger build

---------

Co-authored-by: Zhen Wang <643348094@qq.com>
1 parent 21ddb82 commit 63fd9ac

2 files changed

Lines changed: 19 additions & 13 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import scala.concurrent.{ExecutionContext, Promise}
2626
import scala.concurrent.duration.NANOSECONDS
2727
import scala.util.control.NonFatal
2828

29-
import org.apache.spark.{broadcast, Partition, SparkContext, TaskContext}
29+
import org.apache.spark.{broadcast, Partition, SparkContext, SparkException, TaskContext}
3030
import org.apache.spark.rdd.RDD
3131
import org.apache.spark.sql.catalyst.InternalRow
3232
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -102,14 +102,8 @@ case class CometBroadcastExchangeExec(
102102
@transient
103103
private lazy val maxBroadcastRows = 512000000
104104

105-
private var numPartitions: Option[Int] = None
106-
107-
def setNumPartitions(numPartitions: Int): CometBroadcastExchangeExec = {
108-
this.numPartitions = Some(numPartitions)
109-
this
110-
}
111105
def getNumPartitions(): Int = {
112-
numPartitions.getOrElse(child.executeColumnar().getNumPartitions)
106+
child.executeColumnar().getNumPartitions
113107
}
114108

115109
@transient
@@ -224,6 +218,18 @@ case class CometBroadcastExchangeExec(
224218
new CometBatchRDD(sparkContext, getNumPartitions(), broadcasted)
225219
}
226220

221+
// Since https://issues.apache.org/jira/browse/SPARK-48195, a Spark plan caches the RDD it creates.
222+
// Since we may change the number of partitions in CometBatchRDD,
223+
// we need a method that always creates a new CometBatchRDD.
224+
def executeColumnar(numPartitions: Int): RDD[ColumnarBatch] = {
225+
if (isCanonicalizedPlan) {
226+
throw SparkException.internalError("A canonicalized plan is not supposed to be executed.")
227+
}
228+
229+
val broadcasted = executeBroadcast[Array[ChunkedByteBuffer]]()
230+
new CometBatchRDD(sparkContext, numPartitions, broadcasted)
231+
}
232+
227233
override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
228234
try {
229235
relationFuture.get(timeout, TimeUnit.SECONDS).asInstanceOf[broadcast.Broadcast[T]]
@@ -276,7 +282,7 @@ object CometBroadcastExchangeExec {
276282
*/
277283
class CometBatchRDD(
278284
sc: SparkContext,
279-
numPartitions: Int,
285+
val numPartitions: Int,
280286
value: broadcast.Broadcast[Array[ChunkedByteBuffer]])
281287
extends RDD[ColumnarBatch](sc, Nil) {
282288

spark/src/main/scala/org/apache/spark/sql/comet/operators.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,16 +272,16 @@ abstract class CometNativeExec extends CometExec {
272272
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
273273
plan match {
274274
case c: CometBroadcastExchangeExec =>
275-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
275+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
276276
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
277-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
277+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
278278
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
279-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
279+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
280280
case BroadcastQueryStageExec(
281281
_,
282282
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
283283
_) =>
284-
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
284+
inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions)
285285
case _: CometNativeExec =>
286286
// no-op
287287
case _ if idx == firstNonBroadcastPlan.get._2 =>

0 commit comments

Comments
 (0)