@@ -89,7 +89,7 @@ case class CometPlanner(session: SparkSession) extends Rule[SparkPlan] with Logg
8989 // Phase 1 must run before BroadcastConsumerIndex.build: Phase 1's generic-exec prediction
9090 // reads children's LIKELY_COMET tags (post-order walk), and the index reads BHJ tags to
9191 // decide which broadcasts have a Comet consumer.
92- val annotated1 = phase1LikelyComet(
92+ val annotated1 : SparkPlan = phase1LikelyComet(
9393 prepared,
9494 PlanningContext (
9595 session = session,
@@ -107,8 +107,9 @@ case class CometPlanner(session: SparkSession) extends Rule[SparkPlan] with Logg
107107
108108 val annotated2 = phase2Decision(annotated1, context)
109109 val emitted = phase3Emit(annotated2, context)
110- val reverted = revertOrphanedBroadcasts(emitted)
111- val cleaned = cleanupLogicalLinks(reverted)
110+ val broadcastsReverted = revertBroadcastsWithoutCometConsumer(emitted)
111+ val shufflesReverted = revertRedundantColumnarShuffle(broadcastsReverted)
112+ val cleaned = cleanupLogicalLinks(shufflesReverted)
112113 val blocked = convertBlocks(cleaned)
113114 val finalPlan = postPass(blocked, context)
114115
@@ -179,7 +180,7 @@ case class CometPlanner(session: SparkSession) extends Rule[SparkPlan] with Logg
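179180 * Shuffle generally doesn't need the equivalent revert because a Spark parent with a Comet columnar
180181 * shuffle child is handled naturally by Spark's transition insertion; the one exception, a columnar shuffle between two Spark aggregates, is reverted separately by `revertRedundantColumnarShuffle`.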
181182 */
182- private def revertOrphanedBroadcasts (plan : SparkPlan ): SparkPlan = {
183+ private def revertBroadcastsWithoutCometConsumer (plan : SparkPlan ): SparkPlan = {
183184 if (CometConf .COMET_EXEC_BROADCAST_FORCE_ENABLED .get()) {
184185 return plan
185186 }
@@ -199,6 +200,50 @@ case class CometPlanner(session: SparkSession) extends Rule[SparkPlan] with Logg
199200 out
200201 }
201202
/**
 * Undo columnar shuffles that sit directly between two Spark hash aggregates.
 *
 * A `CometShuffleExchangeExec` is swapped back to its `originalPlan` (re-parented onto the
 * exchange's current child) when all of the following hold:
 *   - its `shuffleType` is `CometColumnarShuffle`,
 *   - its child is a `HashAggregateExec` or `ObjectHashAggregateExec`,
 *   - its parent is a `HashAggregateExec` or `ObjectHashAggregateExec`.
 *
 * This mirrors the legacy `revertRedundantColumnarShuffle` (PR #4010). In the
 * partial/final aggregate pattern where both aggregates fall back to Spark, keeping the columnar
 * shuffle would add row->arrow->shuffle->arrow->row conversion with no Comet consumer on either
 * side.
 *
 * Phase 1 predicts shuffles optimistically so that the legitimate `Sort-over-Spark-leaf` pattern
 * can convert (the shuffle performs row->arrow at exchange time). That same optimism creates the
 * redundant shape when both ends stay Spark, and this pass removes it. The match is deliberately
 * limited to aggregate-shuffle-aggregate; other Spark-Comet-Spark sandwiches are left to
 * `revertBroadcastsWithoutCometConsumer` or to Spark's transition insertion.
 *
 * @param plan the physical plan to rewrite
 * @return the plan with every matching exchange replaced; a debug message with the number of
 *         replacements is logged when at least one was made
 */
private def revertRedundantColumnarShuffle(plan: SparkPlan): SparkPlan = {
  // The two Spark aggregate operators this pass recognises on either side of the exchange.
  def sparkAggregate(node: SparkPlan): Boolean = node match {
    case _: org.apache.spark.sql.execution.aggregate.HashAggregateExec => true
    case _: org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec => true
    case _ => false
  }

  // A columnar Comet exchange whose input is one of the aggregates above.
  def columnarOverAggregate(exchange: CometShuffleExchangeExec): Boolean =
    exchange.shuffleType == org.apache.spark.sql.comet.execution.shuffle.CometColumnarShuffle &&
      sparkAggregate(exchange.child)

  def redundant(node: SparkPlan): Boolean = node match {
    case exchange: CometShuffleExchangeExec => columnarOverAggregate(exchange)
    case _ => false
  }

  // Counted only for the debug log below; it does not influence the rewrite.
  var revertCount = 0
  val rewritten = plan.transform {
    case parent if sparkAggregate(parent) && parent.children.exists(redundant) =>
      val swapped = parent.children.map {
        case exchange: CometShuffleExchangeExec if columnarOverAggregate(exchange) =>
          revertCount += 1
          exchange.originalPlan.withNewChildren(Seq(exchange.child))
        case untouched => untouched
      }
      parent.withNewChildren(swapped)
  }
  if (revertCount > 0) {
    logDebug(s"CometPlanner: reverted $revertCount redundant columnar shuffles")
  }
  rewritten
}
246+
/**
 * Whether `node` counts as native-compatible: it is either a `CometNativeExec` or carries a
 * `CometTags.NATIVE_OP` tag value.
 */
private def isNativeCompatible(node: SparkPlan): Boolean = node match {
  case _: CometNativeExec => true
  case other => other.getTagValue(CometTags.NATIVE_OP).isDefined
}
204249
0 commit comments