@@ -109,8 +109,11 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
109109 count
110110 }
111111
112- // Two passes: strip transitions first (they assert child.supportsColumnar in constructors),
113- // then revert Comet operators to row-based Spark equivalents.
112+ // Three passes:
113+ // 1. Strip existing transitions (they assert child.supportsColumnar in constructors)
114+ // 2. Revert Comet operators to row-based Spark equivalents
115+ // 3. Re-insert ColumnarToRowExec where a columnar child feeds a row-based parent
116+ // (e.g. QueryStageExec from a prior CometShuffleExchangeExec stage)
114117 private [rules] def revertToSpark (plan : SparkPlan ): SparkPlan = {
115118 val stripped = plan.transformDown {
116119 case CometNativeColumnarToRowExec (child) => child
@@ -119,7 +122,7 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
119122 case sparkToColumnar : CometSparkToColumnarExec => sparkToColumnar.child
120123 case RowToColumnarExec (child) => child
121124 }
122- stripped.transformUp {
125+ val reverted = stripped.transformUp {
123126 case cometShuffle : CometShuffleExchangeExec =>
124127 cometShuffle.originalPlan.withNewChildren(Seq (cometShuffle.child))
125128 case cometExec : CometExec =>
@@ -129,5 +132,16 @@ case class RevertNativeForTransitionHeavyStages(session: SparkSession)
129132 cometExec.originalPlan
130133 }
131134 }
135+ insertTransitions(reverted)
136+ }
137+
138+ private def insertTransitions (plan : SparkPlan ): SparkPlan = {
139+ plan.transformUp {
140+ case node if ! node.isInstanceOf [QueryStageExec ] && ! node.supportsColumnar =>
141+ val newChildren = node.children.map { child =>
142+ if (child.supportsColumnar) ColumnarToRowExec (child) else child
143+ }
144+ if (newChildren != node.children) node.withNewChildren(newChildren) else node
145+ }
132146 }
133147}
0 commit comments