@@ -191,6 +191,13 @@ case class CometExecRule(session: SparkSession)
191191
192192 private def isCometNative (op : SparkPlan ): Boolean = op.isInstanceOf [CometNativeExec ]
193193
194+ private def producesArrowOutput (plan : SparkPlan ): Boolean = plan match {
195+ case _ : CometNativeExec => true
196+ case u : CometUnionExec => u.children.forall(producesArrowOutput)
197+ case c : CometCoalesceExec => producesArrowOutput(c.child)
198+ case _ => false
199+ }
200+
194201 // spotless:off
195202
196203 /**
@@ -670,17 +677,17 @@ case class CometExecRule(session: SparkSession)
670677 private def convertToComet (op : SparkPlan , handler : CometOperatorSerde [_]): Option [SparkPlan ] = {
671678 val serde = handler.asInstanceOf [CometOperatorSerde [SparkPlan ]]
672679 if (isOperatorEnabled(serde, op)) {
673- // For operators that require native children (like writes), check if all data-producing
674- // children are CometExec (which includes CometNativeExec and sink operators like
675- // CometUnionExec, CometCoalesceExec, etc.). This prevents runtime failures when the
676- // native operator expects Arrow arrays but receives non-Arrow data.
680+ // Operators with requiresNativeChildren (like the native parquet writer) consume Arrow
681+ // batches from the JNI plan. Only CometNativeExec and pass-through sinks that forward
682+ // such batches unchanged (CometUnionExec, CometCoalesceExec) are safe; other CometExec
683+ // subclasses (CometLocalTableScanExec, CometCollectLimitExec, CometTakeOrderedAndProjectExec)
684+ // produce row-format ColumnarBatches and would crash the native operator at runtime.
677685 if (serde.requiresNativeChildren && op.children.nonEmpty) {
678- // Get the actual data-producing children (unwrap WriteFilesExec if present)
679686 val dataProducingChildren = op.children.flatMap {
680687 case writeFiles : WriteFilesExec => Seq (writeFiles.child)
681688 case other => Seq (other)
682689 }
683- if (! dataProducingChildren.forall(_. isInstanceOf [ CometExec ] )) {
690+ if (! dataProducingChildren.forall(producesArrowOutput )) {
684691 withInfo(op, " Cannot perform native operation because input is not in Arrow format" )
685692 return None
686693 }
0 commit comments