1919
2020package org .apache .comet .rules
2121
22+ import org .apache .spark .sql .CometTestBase
2223import org .apache .spark .sql .comet ._
2324import org .apache .spark .sql .comet .execution .shuffle .CometShuffleExchangeExec
2425import org .apache .spark .sql .execution ._
2526import org .apache .spark .sql .execution .exchange .ShuffleExchangeExec
2627
2728import org .apache .comet .CometConf
28- import org .apache .spark .sql .CometTestBase
2929
3030class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
3131
@@ -77,7 +77,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
7777 CometConf .COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS .key -> " 10" ) {
7878 withTempView(" test_data" ) {
7979 spark.range(10 ).toDF(" id" ).createOrReplaceTempView(" test_data" )
80- val sparkPlan = createSparkPlan(" SELECT id, id * 2 as doubled FROM test_data WHERE id > 5" )
80+ val sparkPlan =
81+ createSparkPlan(" SELECT id, id * 2 as doubled FROM test_data WHERE id > 5" )
8182 val cometPlan = applyCometExecRule(sparkPlan)
8283
8384 val rule = RevertNativeForTransitionHeavyStages (spark)
@@ -141,7 +142,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
141142 val rule = RevertNativeForTransitionHeavyStages (spark)
142143 val reverted = rule.revertToSpark(cometPlan)
143144
144- assert(countCometExecs(reverted) == 0 ,
145+ assert(
146+ countCometExecs(reverted) == 0 ,
145147 s " Should have no CometExec nodes after revert, plan: \n ${reverted.treeString}" )
146148 }
147149 }
@@ -162,7 +164,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
162164 val reverted = rule.revertToSpark(cometPlan)
163165
164166 // Reverted plan should have same output schema
165- assert(reverted.output.map(_.name) == cometPlan.output.map(_.name),
167+ assert(
168+ reverted.output.map(_.name) == cometPlan.output.map(_.name),
166169 " Output schema should be preserved after revert" )
167170 }
168171 }
@@ -183,7 +186,8 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
183186
184187 val rule = RevertNativeForTransitionHeavyStages (spark)
185188 val result = rule.revertToSpark(cometPlan)
186- assert(countCometExecs(result) == 0 ,
189+ assert(
190+ countCometExecs(result) == 0 ,
187191 s " All CometExec should be reverted. Plan: \n ${result.treeString}" )
188192 }
189193 }
@@ -205,14 +209,14 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
205209 if (cometShuffles.nonEmpty) {
206210 val rule = RevertNativeForTransitionHeavyStages (spark)
207211 val reverted = rule.revertToSpark(cometPlan)
208- val remainingCometShuffles = reverted.collect {
209- case s : CometShuffleExchangeExec => s
212+ val remainingCometShuffles = reverted.collect { case s : CometShuffleExchangeExec =>
213+ s
210214 }
211- assert(remainingCometShuffles.isEmpty,
215+ assert(
216+ remainingCometShuffles.isEmpty,
212217 " CometShuffleExchangeExec should be reverted to ShuffleExchangeExec" )
213218 val sparkShuffles = reverted.collect { case s : ShuffleExchangeExec => s }
214- assert(sparkShuffles.nonEmpty,
215- " Should have ShuffleExchangeExec after revert" )
219+ assert(sparkShuffles.nonEmpty, " Should have ShuffleExchangeExec after revert" )
216220 }
217221 }
218222 }
@@ -228,17 +232,17 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
228232 " spark.sql.adaptive.enabled" -> " false" ) {
229233
230234 withTempView(" test_data" ) {
231- spark.range(10 ).selectExpr(" id" , " id % 3 as grp" )
235+ spark
236+ .range(10 )
237+ .selectExpr(" id" , " id % 3 as grp" )
232238 .createOrReplaceTempView(" test_data" )
233- val sparkPlan = createSparkPlan(
234- " SELECT grp, count(*) FROM test_data GROUP BY grp" )
239+ val sparkPlan = createSparkPlan(" SELECT grp, count(*) FROM test_data GROUP BY grp" )
235240 val cometPlan = applyCometExecRule(sparkPlan)
236241
237242 // With high threshold, the non-AQE path should not revert anything
238243 val rule = RevertNativeForTransitionHeavyStages (spark)
239244 val result = rule.apply(cometPlan)
240- assert(result eq cometPlan,
241- " Non-AQE path should not revert when below threshold" )
245+ assert(result eq cometPlan, " Non-AQE path should not revert when below threshold" )
242246 }
243247 }
244248 }
@@ -259,10 +263,12 @@ class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase {
259263 val pairs = rule.countTransitions(cometPlan)
260264 val result = rule.apply(cometPlan)
261265 if (pairs <= 2 ) {
262- assert(result eq cometPlan,
266+ assert(
267+ result eq cometPlan,
263268 s " Plan with $pairs pairs should NOT be reverted at threshold 2 " )
264269 } else {
265- assert(countCometExecs(result) == 0 ,
270+ assert(
271+ countCometExecs(result) == 0 ,
266272 s " Plan with $pairs pairs should be reverted at threshold 2 " )
267273 }
268274 }
0 commit comments