diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 422232f546..15a3a6752f 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -321,6 +321,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index d0a03eeb75..b95d69e176 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -163,6 +163,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RevertNativeForTransitionHeavyStagesSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 78ea0f0168..fe71793fb9 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -442,6 +442,31 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_EXEC_TRANSITION_REVERT_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, Comet reverts a query stage to Spark row-based execution if the number " + + "of columnar-to-row and row-to-columnar transition pairs exceeds the configured " + + "threshold. This avoids the overhead of repeated format conversions in stages where " + + "many operators fall back to row-based execution.") + .booleanConf + .createWithDefault(true) + + val COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.transitionRevert.maxTransitions") + .category(CATEGORY_EXEC) + .doc( + "The maximum number of columnar-to-row (C2R) transitions allowed in a single query " + + "stage before Comet reverts the entire stage to Spark row-based execution. When " + + "columnar shuffle is enabled, each C2R has a corresponding row-to-columnar (R2C) " + + "conversion to feed back into the columnar shuffle, so the count reflects full " + + "round-trips. Set to 0 to revert any stage with transitions. " + + "Only effective when spark.comet.exec.transitionRevert.enabled is true.") + .intConf + .checkValue(_ >= 0, "Must be >= 0.") + .createWithDefault(5) + val COMET_EXEC_SHUFFLE_COMPRESSION_CODEC: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.compression.codec") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 6c4a92f312..cbf281864a 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf._ -import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions} +import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions, RevertNativeForTransitionHeavyStages} import org.apache.comet.shims.ShimCometSparkSessionExtensions /** @@ -106,8 +106,11 @@ class CometSparkSessionExtensions case class CometExecColumnar(session: SparkSession) extends ColumnarRule { override def preColumnarTransitions: Rule[SparkPlan] = CometExecRule(session) - override def postColumnarTransitions: Rule[SparkPlan] = - EliminateRedundantTransitions(session) + override def postColumnarTransitions: Rule[SparkPlan] = { + val rules = + Seq(EliminateRedundantTransitions(session), RevertNativeForTransitionHeavyStages(session)) + plan => rules.foldLeft(plan) { case (p, rule) => rule(p) } + } } } diff --git a/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala b/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala new file mode 100644 index 0000000000..8fb435b0f3 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.rules + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometExec, CometNativeColumnarToRowExec, CometSparkToColumnarExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, SparkPlan} +import org.apache.spark.sql.execution.adaptive.QueryStageExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeLike} + +import org.apache.comet.CometConf + +/** + * Reverts a query stage to Spark row-based execution when it has too many columnar-to-row (C2R) + * transitions. Each C2R indicates Comet could not keep execution columnar and had to fall back. + * With columnar shuffle enabled, each C2R implies a corresponding R2C round-trip. + */ +case class RevertNativeForTransitionHeavyStages(session: SparkSession) + extends Rule[SparkPlan] + with Logging { + + private lazy val enabled = CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.get() + private lazy val maxTransitions = CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.get() + + override def apply(plan: SparkPlan): SparkPlan = { + if (!enabled) return plan + + if (session.sessionState.conf.adaptiveExecutionEnabled) { + applyForAQE(plan) + } else { + applyForNonAQE(plan) + } + } + + private def applyForAQE(plan: SparkPlan): SparkPlan = { + plan match { + case _: BroadcastExchangeLike => plan + case exchange: ShuffleExchangeLike => + revertStageIfNeeded(exchange.child, exchange.supportsColumnar) + .map(reverted => exchange.withNewChildren(Seq(reverted))) + .getOrElse(plan) + case _ => + revertStageIfNeeded(plan, outputColumnar = false).getOrElse(plan) + } + } + + private def applyForNonAQE(plan: SparkPlan): SparkPlan = { + val withRevertedStages = plan.transformUp { case exchange: ShuffleExchangeLike => + revertStageIfNeeded(exchange.child, exchange.supportsColumnar) + .map(reverted => exchange.withNewChildren(Seq(reverted))) + .getOrElse(exchange) + } + revertStageIfNeeded(withRevertedStages, outputColumnar = false) + .getOrElse(withRevertedStages) + } + + /** + * Reverts the stage if C2R count exceeds threshold. Wraps in R2C if exchange needs columnar. + */ + private def revertStageIfNeeded( + stagePlan: SparkPlan, + outputColumnar: Boolean): Option[SparkPlan] = { + val transitionCount = countTransitions(stagePlan) + if (transitionCount <= maxTransitions) return None + + logInfo( + s"Reverting Comet native execution for stage with $transitionCount C2R transitions " + + s"(threshold: $maxTransitions).") + + val reverted = revertToSpark(stagePlan) + val result = if (outputColumnar && !reverted.supportsColumnar) { + RowToColumnarExec(reverted) + } else { + reverted + } + Some(result) + } + + /** Counts C2R transitions within this stage, stopping at stage boundaries. */ + private[rules] def countTransitions(plan: SparkPlan): Int = { + var count = 0 + def visit(node: SparkPlan): Unit = node match { + case _: QueryStageExec | _: ShuffleExchangeLike | _: BroadcastExchangeLike => () + case _: ColumnarToRowTransition => + count += 1 + node.children.foreach(visit) + case _ => + node.children.foreach(visit) + } + visit(plan) + count + } + + private[rules] def revertToSpark(plan: SparkPlan): SparkPlan = { + val stripped = plan.transformDown { + case CometNativeColumnarToRowExec(child) => child + case CometColumnarToRowExec(child) => child + case ColumnarToRowExec(child) => child + case sparkToColumnar: CometSparkToColumnarExec => sparkToColumnar.child + case RowToColumnarExec(child) => child + } + val reverted = stripped.transformUp { case cometExec: CometExec => + if (cometExec.originalPlan.children.size == cometExec.children.size) { + cometExec.originalPlan.withNewChildren(cometExec.children) + } else { + logWarning( + s"Comet plan and original have different child count for " + + s"${cometExec.getClass.getSimpleName}, using originalPlan as-is.") + cometExec.originalPlan + } + } + insertTransitions(reverted) + } + + private def insertTransitions(plan: SparkPlan): SparkPlan = { + plan.transformUp { + case node if !node.isInstanceOf[QueryStageExec] && !node.supportsColumnar => + val newChildren = node.children.map { child => + if (child.supportsColumnar) ColumnarToRowExec(child) else child + } + if (newChildren != node.children) node.withNewChildren(newChildren) else node + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala b/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala new file mode 100644 index 0000000000..7940dd9f79 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStagesSuite.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.rules + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet._ +import org.apache.spark.sql.execution._ + +import org.apache.comet.CometConf + +class RevertNativeForTransitionHeavyStagesSuite extends CometTestBase { + + private def createSparkPlan(sql: String): SparkPlan = { + var plan: SparkPlan = null + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + plan = spark.sql(sql).queryExecution.executedPlan + } + stripAQEPlan(plan) + } + + private def applyCometExecRule(plan: SparkPlan): SparkPlan = { + CometExecRule(spark).apply(plan) + } + + private def applyFullColumnarPipeline(plan: SparkPlan): SparkPlan = { + val cometPlan = CometScanRule(spark).apply(plan) + val execPlan = CometExecRule(spark).apply(cometPlan) + val withTransitions = ApplyColumnarRulesAndInsertTransitions(Seq.empty, false).apply(execPlan) + EliminateRedundantTransitions(spark).apply(withTransitions) + } + + private def countCometExecs(plan: SparkPlan): Int = { + plan.collect { case _: CometExec => true }.size + } + + private def countC2RNodes(plan: SparkPlan): Int = { + plan.collect { case _: ColumnarToRowTransition => true }.size + } + + test("rule is a no-op when disabled") { + withSQLConf(CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false") { + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = createSparkPlan("SELECT id, id * 2 FROM test_data WHERE id > 5") + val cometPlan = applyCometExecRule(sparkPlan) + assert(countCometExecs(cometPlan) > 0, "Plan should have CometExec nodes") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Rule should be a no-op when disabled") + } + } + } + + test("rule does not revert plan below threshold") { + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "10", + "spark.comet.exec.project.enabled" -> "false") { + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyFullColumnarPipeline(sparkPlan) + + val rule = RevertNativeForTransitionHeavyStages(spark) + val transitions = rule.countTransitions(cometPlan) + assert(transitions > 0, s"Plan should have transitions, got $transitions") + assert(transitions <= 10, s"Transitions should be below threshold") + + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Plan should be unchanged when below threshold") + } + } + } + + test("revertToSpark preserves plan structure") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyCometExecRule(sparkPlan) + val rule = RevertNativeForTransitionHeavyStages(spark) + val reverted = rule.revertToSpark(cometPlan) + + // Reverted plan should have same output schema + assert( + reverted.output.map(_.name) == cometPlan.output.map(_.name), + "Output schema should be preserved after revert") + } + } + } + + test("revertToSpark removes all Comet operators from a plan with transitions") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + + withTempView("test_data") { + spark.range(10).toDF("id").createOrReplaceTempView("test_data") + val sparkPlan = + createSparkPlan("SELECT id, id * 2 as doubled FROM test_data WHERE id > 5") + val cometPlan = applyFullColumnarPipeline(sparkPlan) + assert(countCometExecs(cometPlan) > 0, "Should have CometExec nodes before revert") + + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.revertToSpark(cometPlan) + assert( + countCometExecs(result) == 0, + s"All CometExec should be reverted. Plan:\n${result.treeString}") + } + } + } + + test("non-AQE path applies rule per-stage via transformUp") { + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "10", + "spark.sql.adaptive.enabled" -> "false") { + + withTempView("test_data") { + spark + .range(10) + .selectExpr("id", "id % 3 as grp") + .createOrReplaceTempView("test_data") + val sparkPlan = createSparkPlan("SELECT grp, count(*) FROM test_data GROUP BY grp") + val cometPlan = applyCometExecRule(sparkPlan) + + // With high threshold, the non-AQE path should not revert anything + val rule = RevertNativeForTransitionHeavyStages(spark) + val result = rule.apply(cometPlan) + assert(result eq cometPlan, "Non-AQE path should not revert when below threshold") + } + } + } + + test("revert fires and produces correct results when transitions exceed threshold") { + withParquetTable((0 until 100).map(i => (i, i % 10, s"val_$i")), "tbl") { + val query = "SELECT _2, count(*), sum(_1) FROM tbl GROUP BY _2" + + // Without revert, plan should have CometExec nodes with transitions + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_ENABLED.key -> "false", + "spark.comet.exec.project.enabled" -> "false") { + val df = sql(query) + df.collect() + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(countCometExecs(plan) > 0, "Plan without revert should have CometExec nodes") + assert(countC2RNodes(plan) > 0, "Plan without revert should have C2R transitions") + } + + // With revert enabled at threshold 0, all CometExec should be removed + withSQLConf( + CometConf.COMET_EXEC_TRANSITION_REVERT_MAX_TRANSITIONS.key -> "0", + "spark.comet.exec.project.enabled" -> "false") { + val (_, cometPlan) = checkSparkAnswer(query) + val executedPlan = stripAQEPlan(cometPlan) + assert( + countCometExecs(executedPlan) == 0, + s"Revert should have removed all CometExec nodes:\n${executedPlan.treeString}") + } + } + } + +}