Skip to content

Commit ad17997

Browse files
authored
chore: Refactor some of the scan and sink handling in CometExecRule to reduce duplicate code (apache#2844)
1 parent fe49e40 commit ad17997

4 files changed

Lines changed: 122 additions & 186 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 89 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919

2020
package org.apache.comet.rules
2121

22-
import scala.jdk.CollectionConverters._
23-
2422
import org.apache.spark.sql.SparkSession
2523
import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
2624
import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
@@ -46,7 +44,6 @@ import org.apache.comet.CometSparkSessionExtensions._
4644
import org.apache.comet.rules.CometExecRule.allExecs
4745
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, QueryPlanSerde, Unsupported}
4846
import org.apache.comet.serde.OperatorOuterClass.Operator
49-
import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataType}
5047
import org.apache.comet.serde.operator._
5148
import org.apache.comet.serde.operator.CometDataWritingCommand
5249

@@ -71,13 +68,6 @@ object CometExecRule {
7168
classOf[LocalTableScanExec] -> CometLocalTableScanExec,
7269
classOf[WindowExec] -> CometWindowExec)
7370

74-
/**
75-
* DataWritingCommandExec is handled separately in convertNode since it doesn't follow the
76-
* standard pattern of having CometNativeExec children.
77-
*/
78-
val writeExecs: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] =
79-
Map(classOf[DataWritingCommandExec] -> CometDataWritingCommand)
80-
8171
/**
8272
* Sinks that have a native plan of ScanExec.
8373
*/
@@ -186,57 +176,33 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
186176
def convertNode(op: SparkPlan): SparkPlan = op match {
187177
// Fully native scan for V1
188178
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
189-
val nativeOp = operator2Proto(scan).get
190-
CometNativeScan.createExec(nativeOp, scan)
179+
convertToComet(scan, CometNativeScan).getOrElse(scan)
191180

192181
// Fully native Iceberg scan for V2 (iceberg-rust path)
193182
// Only handle scans with native metadata; SupportsComet scans fall through to isCometScan
194183
// Config checks (COMET_ICEBERG_NATIVE_ENABLED, COMET_EXEC_ENABLED) are done in CometScanRule
195184
case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
196-
operator2Proto(scan) match {
197-
case Some(nativeOp) =>
198-
CometIcebergNativeScan.createExec(nativeOp, scan)
199-
case None =>
200-
// Serialization failed, fall back to CometBatchScanExec
201-
scan
202-
}
185+
convertToComet(scan, CometIcebergNativeScan).getOrElse(scan)
203186

204187
// Comet JVM + native scan for V1 and V2
205188
case op if isCometScan(op) =>
206-
val nativeOp = operator2Proto(op)
207-
CometScanWrapper(nativeOp.get, op)
189+
convertToComet(op, CometScanWrapper).getOrElse(op)
208190

209191
case op if shouldApplySparkToColumnar(conf, op) =>
210-
val cometOp = CometSparkToColumnarExec(op)
211-
val nativeOp = operator2Proto(cometOp)
212-
CometScanWrapper(nativeOp.get, cometOp)
213-
214-
// Handle DataWritingCommandExec specially since it doesn't follow the standard pattern
215-
case exec: DataWritingCommandExec =>
216-
CometExecRule.writeExecs.get(classOf[DataWritingCommandExec]) match {
217-
case Some(handler) if isOperatorEnabled(handler, exec) =>
218-
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(exec.id)
219-
handler
220-
.asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
221-
.convert(exec, builder)
222-
.map(nativeOp =>
223-
handler
224-
.asInstanceOf[CometOperatorSerde[DataWritingCommandExec]]
225-
.createExec(nativeOp, exec))
226-
.getOrElse(exec)
227-
case _ =>
228-
exec
229-
}
192+
convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
193+
194+
case op: DataWritingCommandExec =>
195+
convertToComet(op, CometDataWritingCommand).getOrElse(op)
230196

231197
// For AQE broadcast stage on a Comet broadcast exchange
232198
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
233-
newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
199+
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
234200

235201
case s @ BroadcastQueryStageExec(
236202
_,
237203
ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
238204
_) =>
239-
newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
205+
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
240206

241207
// `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
242208
// exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -273,37 +239,26 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
273239

274240
// For AQE shuffle stage on a Comet shuffle exchange
275241
case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
276-
newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
242+
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
277243

278244
// For AQE shuffle stage on a reused Comet shuffle exchange
279245
// Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
280246
// the query plan won't be re-optimized/planned in non-AQE mode.
281247
case s @ ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
282-
newPlanWithProto(s, CometSinkPlaceHolder(_, s, s))
248+
convertToCometIfAllChildrenAreNative(s, CometExchangeSink).getOrElse(s)
283249

284250
case s: ShuffleExchangeExec =>
285251
// try native shuffle first, then columnar shuffle, then fall back to Spark
286252
// if neither is supported
287253
tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
288254

289255
case op =>
290-
allExecs
256+
val handler = allExecs
291257
.get(op.getClass)
292-
.map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) match {
258+
.map(_.asInstanceOf[CometOperatorSerde[SparkPlan]])
259+
handler match {
293260
case Some(handler) =>
294-
if (op.children.forall(isCometNative)) {
295-
if (isOperatorEnabled(handler, op)) {
296-
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
297-
val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
298-
childOp.foreach(builder.addChildren)
299-
return handler
300-
.convert(op, builder, childOp: _*)
301-
.map(handler.createExec(_, op))
302-
.getOrElse(op)
303-
}
304-
} else {
305-
return op
306-
}
261+
return convertToCometIfAllChildrenAreNative(op, handler).getOrElse(op)
307262
case _ =>
308263
}
309264

@@ -332,25 +287,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
332287
}
333288
}
334289

335-
private def operator2ProtoIfAllChildrenAreNative(op: SparkPlan): Option[Operator] = {
336-
if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
337-
operator2Proto(op, op.children.map(_.asInstanceOf[CometNativeExec].nativeOp): _*)
338-
} else {
339-
None
340-
}
341-
}
342-
343-
/**
344-
* Convert operator to proto and then apply a transformation to wrap the proto in a new plan.
345-
*/
346-
private def newPlanWithProto(op: SparkPlan, fun: Operator => SparkPlan): SparkPlan = {
347-
operator2ProtoIfAllChildrenAreNative(op).map(fun).getOrElse(op)
348-
}
349-
350290
private def tryNativeShuffle(s: ShuffleExchangeExec): Option[SparkPlan] = {
351291
Some(s)
352-
.filter(_ => nativeShuffleSupported(s))
353-
.flatMap(_ => operator2ProtoIfAllChildrenAreNative(s))
292+
.filter(nativeShuffleSupported)
293+
.filter(_.children.forall(_.isInstanceOf[CometNativeExec]))
294+
.flatMap(_ => operator2Proto(s))
354295
.map { nativeOp =>
355296
// Switch to use Decimal128 regardless of precision, since Arrow native execution
356297
// doesn't support Decimal32 and Decimal64 yet.
@@ -366,7 +307,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
366307
// If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
367308
// convert it to CometColumnarShuffle,
368309
Some(s)
369-
.filter(_ => columnarShuffleSupported(s))
310+
.filter(columnarShuffleSupported)
370311
.flatMap(_ => operator2Proto(s))
371312
.flatMap { nativeOp =>
372313
s.child match {
@@ -819,84 +760,45 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
819760
}
820761

821762
/**
822-
* Convert a Spark plan operator to a protobuf Comet operator.
823-
*
824-
* @param op
825-
* Spark plan operator
826-
* @param childOp
827-
* previously converted protobuf Comet operators, which will be consumed by the Spark plan
828-
* operator as its children
829-
* @return
830-
* The converted Comet native operator for the input `op`, or `None` if the `op` cannot be
831-
* converted to a native operator.
763+
* Fallback for handling sinks that have not been handled explicitly. This method should
764+
* eventually be removed once CometExecRule fully uses the operator serde framework.
832765
*/
833766
private def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
767+
768+
def isCometSink(op: SparkPlan): Boolean = {
769+
op match {
770+
case _: CometSparkToColumnarExec => true
771+
case _: CometSinkPlaceHolder => true
772+
case _ => false
773+
}
774+
}
775+
776+
def isExchangeSink(op: SparkPlan): Boolean = {
777+
op match {
778+
case _: ShuffleExchangeExec => true
779+
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
780+
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
781+
true
782+
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
783+
case BroadcastQueryStageExec(
784+
_,
785+
ReusedExchangeExec(_, _: CometBroadcastExchangeExec),
786+
_) =>
787+
true
788+
case _: BroadcastExchangeExec => true
789+
case _ => false
790+
}
791+
}
792+
834793
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
835794
childOp.foreach(builder.addChildren)
836795

837796
op match {
838-
839-
// Fully native scan for V1
840-
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_DATAFUSION =>
841-
CometNativeScan.convert(scan, builder, childOp: _*)
842-
843-
// Fully native Iceberg scan for V2 (iceberg-rust path)
844-
case scan: CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
845-
CometIcebergNativeScan.convert(scan, builder, childOp: _*)
797+
case op if isExchangeSink(op) =>
798+
CometExchangeSink.convert(op, builder, childOp: _*)
846799

847800
case op if isCometSink(op) =>
848-
val supportedTypes =
849-
op.output.forall(a => supportedDataType(a.dataType, allowComplex = true))
850-
851-
if (!supportedTypes) {
852-
withInfo(op, "Unsupported data type")
853-
return None
854-
}
855-
856-
// These operators are source of Comet native execution chain
857-
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
858-
val source = op.simpleStringWithNodeId()
859-
if (source.isEmpty) {
860-
scanBuilder.setSource(op.getClass.getSimpleName)
861-
} else {
862-
scanBuilder.setSource(source)
863-
}
864-
865-
val ffiSafe = op match {
866-
case _ if isExchangeSink(op) =>
867-
// Source of broadcast exchange batches is ArrowStreamReader
868-
// Source of shuffle exchange batches is NativeBatchDecoderIterator
869-
true
870-
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_COMET =>
871-
// native_comet scan reuses mutable buffers
872-
false
873-
case scan: CometScanExec if scan.scanImpl == CometConf.SCAN_NATIVE_ICEBERG_COMPAT =>
874-
// native_iceberg_compat scan reuses mutable buffers for constant columns
875-
// https://github.com/apache/datafusion-comet/issues/2152
876-
false
877-
case _ =>
878-
false
879-
}
880-
scanBuilder.setArrowFfiSafe(ffiSafe)
881-
882-
val scanTypes = op.output.flatten { attr =>
883-
serializeDataType(attr.dataType)
884-
}
885-
886-
if (scanTypes.length == op.output.length) {
887-
scanBuilder.addAllFields(scanTypes.asJava)
888-
889-
// Sink operators don't have children
890-
builder.clearChildren()
891-
892-
Some(builder.setScan(scanBuilder).build())
893-
} else {
894-
// There is an unsupported scan type
895-
withInfo(
896-
op,
897-
s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above")
898-
None
899-
}
801+
CometScanWrapper.convert(op, builder, childOp: _*)
900802

901803
case _ =>
902804
// Emit warning if:
@@ -910,12 +812,46 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
910812
}
911813
}
912814

913-
private def isOperatorEnabled(handler: CometOperatorSerde[_], op: SparkPlan): Boolean = {
914-
val enabled = handler.enabledConfig.forall(_.get(op.conf))
815+
/**
816+
* Convert a Spark plan to a Comet plan using the specified serde handler, but only if all
817+
* children are native.
818+
*/
819+
private def convertToCometIfAllChildrenAreNative(
820+
op: SparkPlan,
821+
handler: CometOperatorSerde[_]): Option[SparkPlan] = {
822+
if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
823+
convertToComet(op, handler)
824+
} else {
825+
None
826+
}
827+
}
828+
829+
/** Convert a Spark plan to a Comet plan using the specified serde handler */
830+
private def convertToComet(op: SparkPlan, handler: CometOperatorSerde[_]): Option[SparkPlan] = {
831+
val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
832+
if (isOperatorEnabled(serde, op)) {
833+
val builder = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
834+
if (op.children.forall(_.isInstanceOf[CometNativeExec])) {
835+
val childOp = op.children.map(_.asInstanceOf[CometNativeExec].nativeOp)
836+
childOp.foreach(builder.addChildren)
837+
return serde
838+
.convert(op, builder, childOp: _*)
839+
.map(nativeOp => serde.createExec(nativeOp, op))
840+
} else {
841+
return serde
842+
.convert(op, builder)
843+
.map(nativeOp => serde.createExec(nativeOp, op))
844+
}
845+
}
846+
None
847+
}
848+
849+
private def isOperatorEnabled(
850+
handler: CometOperatorSerde[SparkPlan],
851+
op: SparkPlan): Boolean = {
915852
val opName = op.getClass.getSimpleName
916-
if (enabled) {
917-
val opSerde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
918-
opSerde.getSupportLevel(op) match {
853+
if (handler.enabledConfig.forall(_.get(op.conf))) {
854+
handler.getSupportLevel(op) match {
919855
case Unsupported(notes) =>
920856
withInfo(op, notes.getOrElse(""))
921857
false
@@ -952,36 +888,4 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
952888
false
953889
}
954890
}
955-
956-
/**
957-
* Whether the input Spark operator `op` can be considered as a Comet sink, i.e., the start of
958-
* native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or
959-
* `CometSinkPlaceHolder` later in `CometSparkSessionExtensions` after `operator2Proto` is
960-
* called.
961-
*/
962-
private def isCometSink(op: SparkPlan): Boolean = {
963-
if (isExchangeSink(op)) {
964-
return true
965-
}
966-
op match {
967-
case s if isCometScan(s) => true
968-
case _: CometSparkToColumnarExec => true
969-
case _: CometSinkPlaceHolder => true
970-
case _ => false
971-
}
972-
}
973-
974-
private def isExchangeSink(op: SparkPlan): Boolean = {
975-
op match {
976-
case _: ShuffleExchangeExec => true
977-
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
978-
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
979-
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
980-
case BroadcastQueryStageExec(_, ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) =>
981-
true
982-
case _: BroadcastExchangeExec => true
983-
case _ => false
984-
}
985-
}
986-
987891
}

0 commit comments

Comments
 (0)