Skip to content

Commit 9061d20

Browse files
committed
Add suites to yml
1 parent 858c60b commit 9061d20

6 files changed

Lines changed: 38 additions & 30 deletions

File tree

.github/workflows/pr_build_linux.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -321,6 +321,7 @@ jobs:
321321
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
322322
org.apache.comet.rules.CometScanRuleSuite
323323
org.apache.comet.rules.CometExecRuleSuite
324+
org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite
324325
org.apache.spark.sql.CometTPCDSQuerySuite
325326
org.apache.spark.sql.CometTPCDSQueryTestSuite
326327
org.apache.spark.sql.CometTPCHQuerySuite

.github/workflows/pr_build_macos.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,7 @@ jobs:
163163
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
164164
org.apache.comet.rules.CometScanRuleSuite
165165
org.apache.comet.rules.CometExecRuleSuite
166+
org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite
166167
org.apache.spark.sql.CometTPCDSQuerySuite
167168
org.apache.spark.sql.CometTPCDSQueryTestSuite
168169
org.apache.spark.sql.CometTPCHQuerySuite

spark/src/main/scala/org/apache/comet/CometConf.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -465,8 +465,10 @@ object CometConf extends ShimCometConf {
465465
"shuffle still requires at least one R2C at the shuffle boundary. " +
466466
"Only effective when spark.comet.exec.transitionRevert.enabled is true.")
467467
.intConf
468-
.checkValue(_ >= 2, "Must be >= 2. A reverted stage still requires at least one " +
469-
"R2C at the columnar shuffle boundary.")
468+
.checkValue(
469+
_ >= 2,
470+
"Must be >= 2. A reverted stage still requires at least one " +
471+
"R2C at the columnar shuffle boundary.")
470472
.createWithDefault(2)
471473

472474
val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] =

spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -107,9 +107,8 @@ class CometSparkSessionExtensions
107107
override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session)
108108

109109
override def postColumnarTransitions: Rule[SparkPlan] = {
110-
val rules = Seq(
111-
EliminateRedundantTransitions(session),
112-
RevertNativeForTransitionHeavyStages(session))
110+
val rules =
111+
Seq(EliminateRedundantTransitions(session), RevertNativeForTransitionHeavyStages(session))
113112
plan => rules.foldLeft(plan) { case (p, rule) => rule(p) }
114113
}
115114
}

spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,6 @@ import org.apache.comet.CometConf
3434
* Reverts a query stage to Spark row-based execution when it has too many columnar-to-row (C2R)
3535
* transitions. Each C2R indicates Comet could not keep execution columnar and had to fall back.
3636
* With columnar shuffle enabled, each C2R implies a corresponding R2C round-trip.
37-
*
3837
*/
3938
case class RevertNativeForTransitionHeavyStages(session: SparkSession)
4039
extends Rule[SparkPlan]
@@ -66,15 +65,16 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
6665
}
6766

6867
private def applyForNonAQE(plan: SparkPlan): SparkPlan = {
69-
plan.transformUp {
70-
case exchange: ShuffleExchangeLike =>
71-
revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
72-
.map(reverted => exchange.withNewChildren(Seq(reverted)))
73-
.getOrElse(exchange)
68+
plan.transformUp { case exchange: ShuffleExchangeLike =>
69+
revertStageIfNeeded(exchange.child, exchange.supportsColumnar)
70+
.map(reverted => exchange.withNewChildren(Seq(reverted)))
71+
.getOrElse(exchange)
7472
}
7573
}
7674

77-
/** Reverts the stage if C2R count exceeds threshold. Wraps in R2C if exchange needs columnar. */
75+
/**
76+
* Reverts the stage if C2R count exceeds threshold. Wraps in R2C if exchange needs columnar.
77+
*/
7878
private def revertStageIfNeeded(
7979
stagePlan: SparkPlan,
8080
outputColumnar: Boolean): Option[SparkPlan] = {
@@ -94,7 +94,6 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
9494
Some(result)
9595
}
9696

97-
9897
/** Counts C2R transitions within this stage, stopping at stage boundaries. */
9998
private[rules] def countTransitions(plan: SparkPlan): Int = {
10099
var count = 0

spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,13 @@
1919

2020
package org.apache.comet.rules
2121

22+
import org.apache.spark.sql.CometTestBase
2223
import org.apache.spark.sql.comet._
2324
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
2425
import org.apache.spark.sql.execution._
2526
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
2627

2728
import org.apache.comet.CometConf
28-
import org.apache.spark.sql.CometTestBase
2929

3030
class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
3131

@@ -77,7 +77,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
7777
CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "10") {
7878
withTempView("test_data") {
7979
spark.range(10).toDF("id").createOrReplaceTempView("test_data")
80-
val sparkPlan = createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5")
80+
val sparkPlan =
81+
createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5")
8182
val cometPlan = applyCometExecRule(sparkPlan)
8283

8384
val rule = RevertNativeForTransitionHeavyStages(spark)
@@ -141,7 +142,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
141142
val rule = RevertNativeForTransitionHeavyStages(spark)
142143
val reverted = rule.revertToSpark(cometPlan)
143144

144-
assert(countCometExecs(reverted) == 0,
145+
assert(
146+
countCometExecs(reverted) == 0,
145147
s"Should have no CometExec nodes after revert, plan:\n${reverted.treeString}")
146148
}
147149
}
@@ -162,7 +164,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
162164
val reverted = rule.revertToSpark(cometPlan)
163165

164166
// Reverted plan should have same output schema
165-
assert(reverted.output.map(_.name) == cometPlan.output.map(_.name),
167+
assert(
168+
reverted.output.map(_.name) == cometPlan.output.map(_.name),
166169
"Output schema should be preserved after revert")
167170
}
168171
}
@@ -183,7 +186,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
183186

184187
val rule = RevertNativeForTransitionHeavyStages(spark)
185188
val result = rule.revertToSpark(cometPlan)
186-
assert(countCometExecs(result) == 0,
189+
assert(
190+
countCometExecs(result) == 0,
187191
s"All CometExec should be reverted. Plan:\n${result.treeString}")
188192
}
189193
}
@@ -205,14 +209,14 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
205209
if (cometShuffles.nonEmpty) {
206210
val rule = RevertNativeForTransitionHeavyStages(spark)
207211
val reverted = rule.revertToSpark(cometPlan)
208-
val remainingCometShuffles = reverted.collect {
209-
case s: CometShuffleExchangeExec => s
212+
val remainingCometShuffles = reverted.collect { case s: CometShuffleExchangeExec =>
213+
s
210214
}
211-
assert(remainingCometShuffles.isEmpty,
215+
assert(
216+
remainingCometShuffles.isEmpty,
212217
"CometShuffleExchangeExec should be reverted to ShuffleExchangeExec")
213218
val sparkShuffles = reverted.collect { case s: ShuffleExchangeExec => s }
214-
assert(sparkShuffles.nonEmpty,
215-
"Should have ShuffleExchangeExec after revert")
219+
assert(sparkShuffles.nonEmpty, "Should have ShuffleExchangeExec after revert")
216220
}
217221
}
218222
}
@@ -228,17 +232,17 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
228232
"spark.sql.adaptive.enabled" -> "false") {
229233

230234
withTempView("test_data") {
231-
spark.range(10).selectExpr("id", "id % 3 as grp")
235+
spark
236+
.range(10)
237+
.selectExpr("id", "id % 3 as grp")
232238
.createOrReplaceTempView("test_data")
233-
val sparkPlan = createSparkPlan(
234-
"SELECT grp, count(*) FROM test_data GROUP BY grp")
239+
val sparkPlan = createSparkPlan("SELECT grp, count(*) FROM test_data GROUP BY grp")
235240
val cometPlan = applyCometExecRule(sparkPlan)
236241

237242
// With high threshold, the non-AQE path should not revert anything
238243
val rule = RevertNativeForTransitionHeavyStages(spark)
239244
val result = rule.apply(cometPlan)
240-
assert(result eq cometPlan,
241-
"Non-AQE path should not revert when below threshold")
245+
assert(result eq cometPlan, "Non-AQE path should not revert when below threshold")
242246
}
243247
}
244248
}
@@ -259,10 +263,12 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
259263
val pairs = rule.countTransitions(cometPlan)
260264
val result = rule.apply(cometPlan)
261265
if (pairs <= 2) {
262-
assert(result eq cometPlan,
266+
assert(
267+
result eq cometPlan,
263268
s"Plan with $pairs pairs should NOT be reverted at threshold 2")
264269
} else {
265-
assert(countCometExecs(result) == 0,
270+
assert(
271+
countCometExecs(result) == 0,
266272
s"Plan with $pairs pairs should be reverted at threshold 2")
267273
}
268274
}

0 commit comments

Comments
 (0)