Skip to content

Commit 406bb2f

Browse files
committed
Whitelist Arrow-producing CometExec subclasses for native writer
1 parent e2ebc26 commit 406bb2f

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ case class CometExecRule(session: SparkSession)
191191

192192
private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec]
193193

194+
private def producesArrowOutput(plan: SparkPlan): Boolean = plan match {
195+
case _: CometNativeExec => true
196+
case u: CometUnionExec => u.children.forall(producesArrowOutput)
197+
case c: CometCoalesceExec => producesArrowOutput(c.child)
198+
case _ => false
199+
}
200+
194201
// spotless:off
195202

196203
/**
@@ -670,17 +677,17 @@ case class CometExecRule(session: SparkSession)
670677
private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
671678
val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
672679
if (isOperatorEnabled(serde, op)) {
673-
// For operators that require native children (like writes), check if all data-producing
674-
// children are CometExec (which includes CometNativeExec and sink operators like
675-
// CometUnionExec, CometCoalesceExec, etc.). This prevents runtime failures when the
676-
// native operator expects Arrow arrays but receives non-Arrow data.
680+
// Operators with requiresNativeChildren (like the native parquet writer) consume Arrow
681+
// batches from the JNI plan. Only CometNativeExec and pass-through sinks that forward
682+
// such batches unchanged (CometUnionExec, CometCoalesceExec) are safe; other CometExec
683+
// subclasses (CometLocalTableScanExec, CometCollectLimitExec, CometTakeOrderedAndProjectExec)
684+
// produce row-format ColumnarBatches and would crash the native operator at runtime.
677685
if (serde.requiresNativeChildren && op.children.nonEmpty) {
678-
// Get the actual data-producing children (unwrap WriteFilesExec if present)
679686
val dataProducingChildren = op.children.flatMap {
680687
case writeFiles: WriteFilesExec => Seq(writeFiles.child)
681688
case other => Seq(other)
682689
}
683-
if (!dataProducingChildren.forall(_.isInstanceOf[CometExec])) {
690+
if (!dataProducingChildren.forall(producesArrowOutput)) {
684691
withInfo(op, "Cannot perform native operation because input is not in Arrow format")
685692
return None
686693
}

0 commit comments

Comments
 (0)