Skip to content

Commit 3b51521

Browse files
committed
Fix tests
1 parent 3ed99e4 commit 3b51521

1 file changed

Lines changed: 17 additions & 3 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/RevertNativeForTransitionHeavyStages.scala

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,11 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
109109
count
110110
}
111111

112-
// Two passes: strip transitions first (they assert child.supportsColumnar in constructors),
113-
// then revert Comet operators to row-based Spark equivalents.
112+
// Three passes:
113+
// 1. Strip existing transitions (they assert child.supportsColumnar in constructors)
114+
// 2. Revert Comet operators to row-based Spark equivalents
115+
// 3. Re-insert ColumnarToRowExec where a columnar child feeds a row-based parent
116+
// (e.g. QueryStageExec from a prior CometShuffleExchangeExec stage)
114117
private[rules] def revertToSpark(plan: SparkPlan): SparkPlan = {
115118
val stripped = plan.transformDown {
116119
case CometNativeColumnarToRowExec(child) => child
@@ -119,7 +122,7 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
119122
case sparkToColumnar: CometSparkToColumnarExec => sparkToColumnar.child
120123
case RowToColumnarExec(child) => child
121124
}
122-
stripped.transformUp {
125+
val reverted = stripped.transformUp {
123126
case cometShuffle: CometShuffleExchangeExec =>
124127
cometShuffle.originalPlan.withNewChildren(Seq(cometShuffle.child))
125128
case cometExec: CometExec =>
@@ -129,5 +132,16 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
129132
cometExec.originalPlan
130133
}
131134
}
135+
insertTransitions(reverted)
136+
}
137+
138+
private def insertTransitions(plan: SparkPlan): SparkPlan = {
139+
plan.transformUp {
140+
case node if !node.isInstanceOf[QueryStageExec] && !node.supportsColumnar =>
141+
val newChildren = node.children.map { child =>
142+
if (child.supportsColumnar) ColumnarToRowExec(child) else child
143+
}
144+
if (newChildren != node.children) node.withNewChildren(newChildren) else node
145+
}
132146
}
133147
}

0 commit comments

Comments
 (0)