1919
2020package org .apache .comet .rules
2121
22- import scala .jdk .CollectionConverters ._
23-
2422import org .apache .spark .sql .SparkSession
2523import org .apache .spark .sql .catalyst .expressions .{Divide , DoubleLiteral , EqualNullSafe , EqualTo , Expression , FloatLiteral , GreaterThan , GreaterThanOrEqual , KnownFloatingPointNormalized , LessThan , LessThanOrEqual , NamedExpression , Remainder }
2624import org .apache .spark .sql .catalyst .optimizer .NormalizeNaNAndZero
@@ -46,7 +44,6 @@ import org.apache.comet.CometSparkSessionExtensions._
4644import org .apache .comet .rules .CometExecRule .allExecs
4745import org .apache .comet .serde .{CometOperatorSerde , Compatible , Incompatible , OperatorOuterClass , QueryPlanSerde , Unsupported }
4846import org .apache .comet .serde .OperatorOuterClass .Operator
49- import org .apache .comet .serde .QueryPlanSerde .{serializeDataType , supportedDataType }
5047import org .apache .comet .serde .operator ._
5148import org .apache .comet .serde .operator .CometDataWritingCommand
5249
@@ -71,13 +68,6 @@ object CometExecRule {
7168 classOf [LocalTableScanExec ] -> CometLocalTableScanExec ,
7269 classOf [WindowExec ] -> CometWindowExec )
7370
74- /**
75- * DataWritingCommandExec is handled separately in convertNode since it doesn't follow the
76- * standard pattern of having CometNativeExec children.
77- */
78- val writeExecs : Map [Class [_ <: SparkPlan ], CometOperatorSerde [_]] =
79- Map (classOf [DataWritingCommandExec ] -> CometDataWritingCommand )
80-
8171 /**
8272 * Sinks that have a native plan of ScanExec.
8373 */
@@ -186,57 +176,33 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
186176 def convertNode (op : SparkPlan ): SparkPlan = op match {
187177 // Fully native scan for V1
188178 case scan : CometScanExec if scan.scanImpl == CometConf .SCAN_NATIVE_DATAFUSION =>
189- val nativeOp = operator2Proto(scan).get
190- CometNativeScan .createExec(nativeOp, scan)
179+ convertToComet(scan, CometNativeScan ).getOrElse(scan)
191180
192181 // Fully native Iceberg scan for V2 (iceberg-rust path)
193182 // Only handle scans with native metadata; SupportsComet scans fall through to isCometScan
194183 // Config checks (COMET_ICEBERG_NATIVE_ENABLED, COMET_EXEC_ENABLED) are done in CometScanRule
195184 case scan : CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
196- operator2Proto(scan) match {
197- case Some (nativeOp) =>
198- CometIcebergNativeScan .createExec(nativeOp, scan)
199- case None =>
200- // Serialization failed, fall back to CometBatchScanExec
201- scan
202- }
185+ convertToComet(scan, CometIcebergNativeScan ).getOrElse(scan)
203186
204187 // Comet JVM + native scan for V1 and V2
205188 case op if isCometScan(op) =>
206- val nativeOp = operator2Proto(op)
207- CometScanWrapper (nativeOp.get, op)
189+ convertToComet(op, CometScanWrapper ).getOrElse(op)
208190
209191 case op if shouldApplySparkToColumnar(conf, op) =>
210- val cometOp = CometSparkToColumnarExec (op)
211- val nativeOp = operator2Proto(cometOp)
212- CometScanWrapper (nativeOp.get, cometOp)
213-
214- // Handle DataWritingCommandExec specially since it doesn't follow the standard pattern
215- case exec : DataWritingCommandExec =>
216- CometExecRule .writeExecs.get(classOf [DataWritingCommandExec ]) match {
217- case Some (handler) if isOperatorEnabled(handler, exec) =>
218- val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(exec.id)
219- handler
220- .asInstanceOf [CometOperatorSerde [DataWritingCommandExec ]]
221- .convert(exec, builder)
222- .map(nativeOp =>
223- handler
224- .asInstanceOf [CometOperatorSerde [DataWritingCommandExec ]]
225- .createExec(nativeOp, exec))
226- .getOrElse(exec)
227- case _ =>
228- exec
229- }
192+ convertToComet(op, CometSparkToColumnarExec ).getOrElse(op)
193+
194+ case op : DataWritingCommandExec =>
195+ convertToComet(op, CometDataWritingCommand ).getOrElse(op)
230196
231197 // For AQE broadcast stage on a Comet broadcast exchange
232198 case s @ BroadcastQueryStageExec (_, _ : CometBroadcastExchangeExec , _) =>
233- newPlanWithProto (s, CometSinkPlaceHolder (_, s, s) )
199+ convertToCometIfAllChildrenAreNative (s, CometExchangeSink ).getOrElse(s )
234200
235201 case s @ BroadcastQueryStageExec (
236202 _,
237203 ReusedExchangeExec (_, _ : CometBroadcastExchangeExec ),
238204 _) =>
239- newPlanWithProto (s, CometSinkPlaceHolder (_, s, s) )
205+ convertToCometIfAllChildrenAreNative (s, CometExchangeSink ).getOrElse(s )
240206
241207 // `CometBroadcastExchangeExec`'s broadcast output is not compatible with Spark's broadcast
242208 // exchange. It is only used for Comet native execution. We only transform Spark broadcast
@@ -273,37 +239,26 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
273239
274240 // For AQE shuffle stage on a Comet shuffle exchange
275241 case s @ ShuffleQueryStageExec (_, _ : CometShuffleExchangeExec , _) =>
276- newPlanWithProto (s, CometSinkPlaceHolder (_, s, s) )
242+ convertToCometIfAllChildrenAreNative (s, CometExchangeSink ).getOrElse(s )
277243
278244 // For AQE shuffle stage on a reused Comet shuffle exchange
279245 // Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
280246 // the query plan won't be re-optimized/planned in non-AQE mode.
281247 case s @ ShuffleQueryStageExec (_, ReusedExchangeExec (_, _ : CometShuffleExchangeExec ), _) =>
282- newPlanWithProto (s, CometSinkPlaceHolder (_, s, s) )
248+ convertToCometIfAllChildrenAreNative (s, CometExchangeSink ).getOrElse(s )
283249
284250 case s : ShuffleExchangeExec =>
285251 // try native shuffle first, then columnar shuffle, then fall back to Spark
286252 // if neither are supported
287253 tryNativeShuffle(s).orElse(tryColumnarShuffle(s)).getOrElse(s)
288254
289255 case op =>
290- allExecs
256+ val handler = allExecs
291257 .get(op.getClass)
292- .map(_.asInstanceOf [CometOperatorSerde [SparkPlan ]]) match {
258+ .map(_.asInstanceOf [CometOperatorSerde [SparkPlan ]])
259+ handler match {
293260 case Some (handler) =>
294- if (op.children.forall(isCometNative)) {
295- if (isOperatorEnabled(handler, op)) {
296- val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(op.id)
297- val childOp = op.children.map(_.asInstanceOf [CometNativeExec ].nativeOp)
298- childOp.foreach(builder.addChildren)
299- return handler
300- .convert(op, builder, childOp : _* )
301- .map(handler.createExec(_, op))
302- .getOrElse(op)
303- }
304- } else {
305- return op
306- }
261+ return convertToCometIfAllChildrenAreNative(op, handler).getOrElse(op)
307262 case _ =>
308263 }
309264
@@ -332,25 +287,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
332287 }
333288 }
334289
335- private def operator2ProtoIfAllChildrenAreNative (op : SparkPlan ): Option [Operator ] = {
336- if (op.children.forall(_.isInstanceOf [CometNativeExec ])) {
337- operator2Proto(op, op.children.map(_.asInstanceOf [CometNativeExec ].nativeOp): _* )
338- } else {
339- None
340- }
341- }
342-
343- /**
344- * Convert operator to proto and then apply a transformation to wrap the proto in a new plan.
345- */
346- private def newPlanWithProto (op : SparkPlan , fun : Operator => SparkPlan ): SparkPlan = {
347- operator2ProtoIfAllChildrenAreNative(op).map(fun).getOrElse(op)
348- }
349-
350290 private def tryNativeShuffle (s : ShuffleExchangeExec ): Option [SparkPlan ] = {
351291 Some (s)
352- .filter(_ => nativeShuffleSupported(s))
353- .flatMap(_ => operator2ProtoIfAllChildrenAreNative(s))
292+ .filter(nativeShuffleSupported)
293+ .filter(_.children.forall(_.isInstanceOf [CometNativeExec ]))
294+ .flatMap(_ => operator2Proto(s))
354295 .map { nativeOp =>
355296 // Switch to use Decimal128 regardless of precision, since Arrow native execution
356297 // doesn't support Decimal32 and Decimal64 yet.
@@ -366,7 +307,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
366307 // If the child of ShuffleExchangeExec is also a ShuffleExchangeExec, we should not
367308 // convert it to CometColumnarShuffle,
368309 Some (s)
369- .filter(_ => columnarShuffleSupported(s) )
310+ .filter(columnarShuffleSupported)
370311 .flatMap(_ => operator2Proto(s))
371312 .flatMap { nativeOp =>
372313 s.child match {
@@ -819,84 +760,45 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
819760 }
820761
821762 /**
822- * Convert a Spark plan operator to a protobuf Comet operator.
823- *
824- * @param op
825- * Spark plan operator
826- * @param childOp
827- * previously converted protobuf Comet operators, which will be consumed by the Spark plan
828- * operator as its children
829- * @return
830- * The converted Comet native operator for the input `op`, or `None` if the `op` cannot be
831- * converted to a native operator.
763+ * Fallback for handling sinks that have not been handled explicitly. This method should
764+ * eventually be removed once CometExecRule fully uses the operator serde framework.
832765 */
833766 private def operator2Proto (op : SparkPlan , childOp : Operator * ): Option [Operator ] = {
767+
768+ def isCometSink (op : SparkPlan ): Boolean = {
769+ op match {
770+ case _ : CometSparkToColumnarExec => true
771+ case _ : CometSinkPlaceHolder => true
772+ case _ => false
773+ }
774+ }
775+
776+ def isExchangeSink (op : SparkPlan ): Boolean = {
777+ op match {
778+ case _ : ShuffleExchangeExec => true
779+ case ShuffleQueryStageExec (_, _ : CometShuffleExchangeExec , _) => true
780+ case ShuffleQueryStageExec (_, ReusedExchangeExec (_, _ : CometShuffleExchangeExec ), _) =>
781+ true
782+ case BroadcastQueryStageExec (_, _ : CometBroadcastExchangeExec , _) => true
783+ case BroadcastQueryStageExec (
784+ _,
785+ ReusedExchangeExec (_, _ : CometBroadcastExchangeExec ),
786+ _) =>
787+ true
788+ case _ : BroadcastExchangeExec => true
789+ case _ => false
790+ }
791+ }
792+
834793 val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(op.id)
835794 childOp.foreach(builder.addChildren)
836795
837796 op match {
838-
839- // Fully native scan for V1
840- case scan : CometScanExec if scan.scanImpl == CometConf .SCAN_NATIVE_DATAFUSION =>
841- CometNativeScan .convert(scan, builder, childOp : _* )
842-
843- // Fully native Iceberg scan for V2 (iceberg-rust path)
844- case scan : CometBatchScanExec if scan.nativeIcebergScanMetadata.isDefined =>
845- CometIcebergNativeScan .convert(scan, builder, childOp : _* )
797+ case op if isExchangeSink(op) =>
798+ CometExchangeSink .convert(op, builder, childOp : _* )
846799
847800 case op if isCometSink(op) =>
848- val supportedTypes =
849- op.output.forall(a => supportedDataType(a.dataType, allowComplex = true ))
850-
851- if (! supportedTypes) {
852- withInfo(op, " Unsupported data type" )
853- return None
854- }
855-
856- // These operators are source of Comet native execution chain
857- val scanBuilder = OperatorOuterClass .Scan .newBuilder()
858- val source = op.simpleStringWithNodeId()
859- if (source.isEmpty) {
860- scanBuilder.setSource(op.getClass.getSimpleName)
861- } else {
862- scanBuilder.setSource(source)
863- }
864-
865- val ffiSafe = op match {
866- case _ if isExchangeSink(op) =>
867- // Source of broadcast exchange batches is ArrowStreamReader
868- // Source of shuffle exchange batches is NativeBatchDecoderIterator
869- true
870- case scan : CometScanExec if scan.scanImpl == CometConf .SCAN_NATIVE_COMET =>
871- // native_comet scan reuses mutable buffers
872- false
873- case scan : CometScanExec if scan.scanImpl == CometConf .SCAN_NATIVE_ICEBERG_COMPAT =>
874- // native_iceberg_compat scan reuses mutable buffers for constant columns
875- // https://github.com/apache/datafusion-comet/issues/2152
876- false
877- case _ =>
878- false
879- }
880- scanBuilder.setArrowFfiSafe(ffiSafe)
881-
882- val scanTypes = op.output.flatten { attr =>
883- serializeDataType(attr.dataType)
884- }
885-
886- if (scanTypes.length == op.output.length) {
887- scanBuilder.addAllFields(scanTypes.asJava)
888-
889- // Sink operators don't have children
890- builder.clearChildren()
891-
892- Some (builder.setScan(scanBuilder).build())
893- } else {
894- // There are unsupported scan type
895- withInfo(
896- op,
897- s " unsupported Comet operator: ${op.nodeName}, due to unsupported data types above " )
898- None
899- }
801+ CometScanWrapper .convert(op, builder, childOp : _* )
900802
901803 case _ =>
902804 // Emit warning if:
@@ -910,12 +812,46 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
910812 }
911813 }
912814
913- private def isOperatorEnabled (handler : CometOperatorSerde [_], op : SparkPlan ): Boolean = {
914- val enabled = handler.enabledConfig.forall(_.get(op.conf))
815+ /**
816+ * Convert a Spark plan to a Comet plan using the specified serde handler, but only if all
817+ * children are native.
818+ */
819+ private def convertToCometIfAllChildrenAreNative (
820+ op : SparkPlan ,
821+ handler : CometOperatorSerde [_]): Option [SparkPlan ] = {
822+ if (op.children.forall(_.isInstanceOf [CometNativeExec ])) {
823+ convertToComet(op, handler)
824+ } else {
825+ None
826+ }
827+ }
828+
829+ /** Convert a Spark plan to a Comet plan using the specified serde handler */
830+ private def convertToComet (op : SparkPlan , handler : CometOperatorSerde [_]): Option [SparkPlan ] = {
831+ val serde = handler.asInstanceOf [CometOperatorSerde [SparkPlan ]]
832+ if (isOperatorEnabled(serde, op)) {
833+ val builder = OperatorOuterClass .Operator .newBuilder().setPlanId(op.id)
834+ if (op.children.forall(_.isInstanceOf [CometNativeExec ])) {
835+ val childOp = op.children.map(_.asInstanceOf [CometNativeExec ].nativeOp)
836+ childOp.foreach(builder.addChildren)
837+ return serde
838+ .convert(op, builder, childOp : _* )
839+ .map(nativeOp => serde.createExec(nativeOp, op))
840+ } else {
841+ return serde
842+ .convert(op, builder)
843+ .map(nativeOp => serde.createExec(nativeOp, op))
844+ }
845+ }
846+ None
847+ }
848+
849+ private def isOperatorEnabled (
850+ handler : CometOperatorSerde [SparkPlan ],
851+ op : SparkPlan ): Boolean = {
915852 val opName = op.getClass.getSimpleName
916- if (enabled) {
917- val opSerde = handler.asInstanceOf [CometOperatorSerde [SparkPlan ]]
918- opSerde.getSupportLevel(op) match {
853+ if (handler.enabledConfig.forall(_.get(op.conf))) {
854+ handler.getSupportLevel(op) match {
919855 case Unsupported (notes) =>
920856 withInfo(op, notes.getOrElse(" " ))
921857 false
@@ -952,36 +888,4 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
952888 false
953889 }
954890 }
955-
956- /**
957- * Whether the input Spark operator `op` can be considered as a Comet sink, i.e., the start of
958- * native execution. If it is true, we'll wrap `op` with `CometScanWrapper` or
959- * `CometSinkPlaceHolder` later in `CometSparkSessionExtensions` after `operator2proto` is
960- * called.
961- */
962- private def isCometSink (op : SparkPlan ): Boolean = {
963- if (isExchangeSink(op)) {
964- return true
965- }
966- op match {
967- case s if isCometScan(s) => true
968- case _ : CometSparkToColumnarExec => true
969- case _ : CometSinkPlaceHolder => true
970- case _ => false
971- }
972- }
973-
974- private def isExchangeSink (op : SparkPlan ): Boolean = {
975- op match {
976- case _ : ShuffleExchangeExec => true
977- case ShuffleQueryStageExec (_, _ : CometShuffleExchangeExec , _) => true
978- case ShuffleQueryStageExec (_, ReusedExchangeExec (_, _ : CometShuffleExchangeExec ), _) => true
979- case BroadcastQueryStageExec (_, _ : CometBroadcastExchangeExec , _) => true
980- case BroadcastQueryStageExec (_, ReusedExchangeExec (_, _ : CometBroadcastExchangeExec ), _) =>
981- true
982- case _ : BroadcastExchangeExec => true
983- case _ => false
984- }
985- }
986-
987891}
0 commit comments