Skip to content

Commit 858c60b

Browse files
committed
Stage based fallback
1 parent e79183e commit 858c60b

4 files changed

Lines changed: 439 additions & 3 deletions

File tree

spark/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,33 @@ object CometConf extends ShimCometConf {
442442
.booleanConf
443443
.createWithDefault(true)
444444

445+
val COMET_EXEC_TRANSITION_REVERT_ENABLED: ConfigEntry[Boolean] =
446+
conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.enabled")
447+
.category(CATEGORY_EXEC)
448+
.doc(
449+
"When enabled, Comet reverts a query stage to Spark row-based execution if the number " +
450+
"of columnar-to-row and row-to-columnar transition pairs exceeds the configured " +
451+
"threshold. This avoids the overhead of repeated format conversions in stages where " +
452+
"many operators fall back to row-based execution.")
453+
.booleanConf
454+
.createWithDefault(true)
455+
456+
val COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS: ConfigEntry[Int] =
457+
conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.maxTransitions")
458+
.category(CATEGORY_EXEC)
459+
.doc(
460+
"The maximum number of columnar-to-row (C2R) transitions allowed in a single query " +
461+
"stage before Comet reverts the entire stage to Spark row-based execution. When " +
462+
"columnar shuffle is enabled, each C2R has a corresponding row-to-columnar (R2C) " +
463+
"conversion to feed back into the columnar shuffle, so the count reflects full " +
464+
"round-trips. Minimum value is 2 because reverting a stage that feeds a columnar " +
465+
"shuffle still requires at least one R2C at the shuffle boundary. " +
466+
"Only effective when spark.comet.exec.transitionRevert.enabled is true.")
467+
.intConf
468+
.checkValue(_ >= 2, "Must be >= 2. A reverted stage still requires at least one " +
469+
"R2C at the columnar shuffle boundary.")
470+
.createWithDefault(2)
471+
445472
val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] =
446473
conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec")
447474
.category(CATEGORY_SHUFFLE)

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution._
3232
import org.apache.spark.sql.internal.SQLConf
3333

3434
import org.apache.comet.CometConf._
35-
import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions}
35+
import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions, RevertNativeForTransitionHeavyStages}
3636
import org.apache.comet.shims.ShimCometSparkSessionExtensions
3737

3838
/**
@@ -106,8 +106,12 @@ class CometSparkSessionExtensions
106106
case class CometExecColumnar(session: SparkSession) extends ColumnarRule {
107107
override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
108108

109-
override def postColumnarTransitions: Rule[SparkPlan] =
110-
EliminateRedundantTransitions(session)
109+
override def postColumnarTransitions: Rule[SparkPlan] = {
110+
val rules = Seq(
111+
EliminateRedundantTransitions(session),
112+
RevertNativeForTransitionHeavyStages(session))
113+
plan => rules.foldLeft(plan) { case (p, rule) => rule(p) }
114+
}
111115
}
112116
}
113117

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.rules
21+
22+
import org.apache.spark.internal.Logging
23+
import org.apache.spark.sql.SparkSession
24+
import org.apache.spark.sql.catalyst.rules.Rule
25+
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometExec, CometNativeColumnarToRowExec, CometSparkToColumnarExec}
26+
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
27+
import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, SparkPlan}
28+
import org.apache.spark.sql.execution.adaptive.QueryStageExec
29+
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike}
30+
31+
import org.apache.comet.CometConf
32+
33+
/**
34+
* Reverts a query stage to Spark row-based execution when it has too many columnar-to-row (C2R)
35+
* transitions. Each C2R indicates Comet could not keep execution columnar and had to fall back.
36+
* With columnar shuffle enabled, each C2R implies a corresponding R2C round-trip.
37+
*
38+
*/
39+
case class RevertNativeForTransitionHeavyStages(session: SparkSession)
40+
extends Rule[SparkPlan]
41+
with Logging {
42+
43+
private lazy val enabled = CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.get()
44+
private lazy val maxTransitions = CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.get()
45+
46+
override def apply(plan: SparkPlan): SparkPlan = {
47+
if (!enabled) return plan
48+
49+
if (session.sessionState.conf.adaptiveExecutionEnabled) {
50+
applyForAQE(plan)
51+
} else {
52+
applyForNonAQE(plan)
53+
}
54+
}
55+
56+
private def applyForAQE(plan: SparkPlan): SparkPlan = {
57+
plan match {
58+
case _: BroadcastExchangeLike => plan
59+
case exchange: ShuffleExchangeLike =>
60+
revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
61+
.map(reverted => exchange.withNewChildren(Seq(reverted)))
62+
.getOrElse(plan)
63+
case _ =>
64+
revertStageIfNeeded(plan, outputColumnar = false).getOrElse(plan)
65+
}
66+
}
67+
68+
private def applyForNonAQE(plan: SparkPlan): SparkPlan = {
69+
plan.transformUp {
70+
case exchange: ShuffleExchangeLike =>
71+
revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
72+
.map(reverted => exchange.withNewChildren(Seq(reverted)))
73+
.getOrElse(exchange)
74+
}
75+
}
76+
77+
/** Reverts the stage if C2R count exceeds threshold. Wraps in R2C if exchange needs columnar. */
78+
private def revertStageIfNeeded(
79+
stagePlan: SparkPlan,
80+
outputColumnar: Boolean): Option[SparkPlan] = {
81+
val transitionCount = countTransitions(stagePlan)
82+
if (transitionCount <= maxTransitions) return None
83+
84+
logInfo(
85+
s"Reverting Comet native execution for stage with $transitionCount C2R transitions " +
86+
s"(threshold: $maxTransitions).")
87+
88+
val reverted = revertToSpark(stagePlan)
89+
val result = if (outputColumnar && !reverted.supportsColumnar) {
90+
RowToColumnarExec(reverted)
91+
} else {
92+
reverted
93+
}
94+
Some(result)
95+
}
96+
97+
98+
/** Counts C2R transitions within this stage, stopping at stage boundaries. */
99+
private[rules] def countTransitions(plan: SparkPlan): Int = {
100+
var count = 0
101+
def visit(node: SparkPlan): Unit = node match {
102+
case _: QueryStageExec | _: ShuffleExchangeLike => ()
103+
case _: ColumnarToRowTransition =>
104+
count += 1
105+
node.children.foreach(visit)
106+
case _ =>
107+
node.children.foreach(visit)
108+
}
109+
visit(plan)
110+
count
111+
}
112+
113+
// Two passes: strip transitions first (they assert child.supportsColumnar in constructors),
114+
// then revert Comet operators to row-based Spark equivalents.
115+
private[rules] def revertToSpark(plan: SparkPlan): SparkPlan = {
116+
val stripped = plan.transformDown {
117+
case CometNativeColumnarToRowExec(child) => child
118+
case CometColumnarToRowExec(child) => child
119+
case ColumnarToRowExec(child) => child
120+
case sparkToColumnar: CometSparkToColumnarExec => sparkToColumnar.child
121+
case RowToColumnarExec(child) => child
122+
}
123+
stripped.transformUp {
124+
case cometShuffle: CometShuffleExchangeExec =>
125+
cometShuffle.originalPlan.withNewChildren(Seq(cometShuffle.child))
126+
case cometExec: CometExec =>
127+
if (cometExec.originalPlan.children.size == cometExec.children.size) {
128+
cometExec.originalPlan.withNewChildren(cometExec.children)
129+
} else {
130+
cometExec.originalPlan
131+
}
132+
}
133+
}
134+
}

0 commit comments

Comments
 (0)