Skip to content

Commit c8dd7df

Browse files
committed
refactor
1 parent 150e21a commit c8dd7df

3 files changed

Lines changed: 76 additions & 55 deletions

File tree

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
3838
import org.apache.spark._
3939
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
4040
import org.apache.spark.sql.comet._
41+
import org.apache.spark.sql.comet.CometPlanChecker
4142
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
4243
import org.apache.spark.sql.execution._
4344
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -58,7 +59,8 @@ abstract class CometTestBase
5859
with BeforeAndAfterEach
5960
with AdaptiveSparkPlanHelper
6061
with ShimCometSparkSessionExtensions
61-
with ShimCometTestBase {
62+
with ShimCometTestBase
63+
with CometPlanChecker {
6264
import testImplicits._
6365

6466
protected val shuffleManager: String =
@@ -396,26 +398,6 @@ abstract class CometTestBase
396398
checkPlanNotMissingInput(plan)
397399
}
398400

399-
protected def findFirstNonCometOperator(
400-
plan: SparkPlan,
401-
excludedClasses: Class[_]*): Option[SparkPlan] = {
402-
val wrapped = wrapCometSparkToColumnar(plan)
403-
wrapped.foreach {
404-
case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
405-
_: CometIcebergNativeScanExec =>
406-
case _: CometSinkPlaceHolder | _: CometScanWrapper =>
407-
case _: CometColumnarToRowExec =>
408-
case _: CometSparkToColumnarExec =>
409-
case _: CometExec | _: CometShuffleExchangeExec =>
410-
case _: CometBroadcastExchangeExec =>
411-
case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
412-
case op if !excludedClasses.exists(c => c.isAssignableFrom(op.getClass)) =>
413-
return Some(op)
414-
case _ =>
415-
}
416-
None
417-
}
418-
419401
// checks the plan node has no missing inputs
420402
// such nodes are shown in the plan with a leading exclamation mark (!)
421403
// example: !CometWindowExec
@@ -449,14 +431,6 @@ abstract class CometTestBase
449431
}
450432
}
451433

452-
/** Wraps the CometRowToColumn as ScanWrapper, so the child operators will not be checked */
453-
private def wrapCometSparkToColumnar(plan: SparkPlan): SparkPlan = {
454-
plan.transformDown {
455-
// don't care the native operators
456-
case p: CometSparkToColumnarExec => CometScanWrapper(null, p)
457-
}
458-
}
459-
460434
private var _spark: SparkSessionType = _
461435
override protected implicit def spark: SparkSessionType = _spark
462436
protected implicit def sqlContext: SQLContext = _spark.sqlContext

spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
3131
import org.apache.spark.SparkConf
3232
import org.apache.spark.benchmark.Benchmark
3333
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
34-
import org.apache.spark.sql.comet._
35-
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
36-
import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
34+
import org.apache.spark.sql.comet.CometPlanChecker
3735
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3836
import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
3937
import org.apache.spark.sql.internal.SQLConf
@@ -42,7 +40,10 @@ import org.apache.spark.sql.types.DecimalType
4240
import org.apache.comet.CometConf
4341
import org.apache.comet.CometSparkSessionExtensions
4442

45-
trait CometBenchmarkBase extends SqlBasedBenchmark with AdaptiveSparkPlanHelper {
43+
trait CometBenchmarkBase
44+
extends SqlBasedBenchmark
45+
with AdaptiveSparkPlanHelper
46+
with CometPlanChecker {
4647
override def getSparkSession: SparkSession = {
4748
val conf = new SparkConf()
4849
.setAppName("CometReadBenchmark")
@@ -163,28 +164,6 @@ trait CometBenchmarkBase extends SqlBasedBenchmark with AdaptiveSparkPlanHelper
163164
benchmark.run()
164165
}
165166

166-
/**
167-
* Finds the first non-Comet operator in the plan, if any. This is used to verify that
168-
* benchmarks are running fully on Comet native when expected.
169-
*
170-
* Based on CometTestBase.findFirstNonCometOperator.
171-
*/
172-
protected def findFirstNonCometOperator(plan: SparkPlan): Option[SparkPlan] = {
173-
plan.foreach {
174-
case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
175-
_: CometIcebergNativeScanExec =>
176-
case _: CometSinkPlaceHolder | _: CometScanWrapper =>
177-
case _: CometColumnarToRowExec =>
178-
case _: CometSparkToColumnarExec =>
179-
case _: CometExec | _: CometShuffleExchangeExec =>
180-
case _: CometBroadcastExchangeExec =>
181-
case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
182-
case op =>
183-
return Some(op)
184-
}
185-
None
186-
}
187-
188167
protected def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
189168
val testDf = if (partition.isDefined) {
190169
df.write.partitionBy(partition.get)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.spark.sql.comet
21+
22+
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
23+
import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
24+
25+
/**
26+
* Trait providing utilities to check if a Spark plan is fully running on Comet native operators.
27+
* Used by both CometTestBase and CometBenchmarkBase.
28+
*/
29+
trait CometPlanChecker {
30+
31+
/**
32+
* Finds the first non-Comet operator in the plan, if any.
33+
*
34+
* @param plan
35+
* The SparkPlan to check
36+
* @param excludedClasses
37+
* Classes to exclude from the check (these are allowed to be non-Comet)
38+
* @return
39+
* Some(operator) if a non-Comet operator is found, None otherwise
40+
*/
41+
protected def findFirstNonCometOperator(
42+
plan: SparkPlan,
43+
excludedClasses: Class[_]*): Option[SparkPlan] = {
44+
val wrapped = wrapCometSparkToColumnar(plan)
45+
wrapped.foreach {
46+
case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
47+
_: CometIcebergNativeScanExec =>
48+
case _: CometSinkPlaceHolder | _: CometScanWrapper =>
49+
case _: CometColumnarToRowExec =>
50+
case _: CometSparkToColumnarExec =>
51+
case _: CometExec | _: CometShuffleExchangeExec =>
52+
case _: CometBroadcastExchangeExec =>
53+
case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
54+
case op if !excludedClasses.exists(c => c.isAssignableFrom(op.getClass)) =>
55+
return Some(op)
56+
case _ =>
57+
}
58+
None
59+
}
60+
61+
/** Wraps each CometSparkToColumnarExec in a CometScanWrapper so that its child operators are not checked. */
62+
private def wrapCometSparkToColumnar(plan: SparkPlan): SparkPlan = {
63+
plan.transformDown {
64+
// the operators below a CometSparkToColumnarExec are intentionally not checked
65+
case p: CometSparkToColumnarExec => CometScanWrapper(null, p)
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)