Skip to content

Commit 4614e96

Browse files
committed
add checks to microbenchmarks for plan running natively in Comet
1 parent e041a1e commit 4614e96

1 file changed

Lines changed: 49 additions & 1 deletion

File tree

spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,18 @@ import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
3131
import org.apache.spark.SparkConf
3232
import org.apache.spark.benchmark.Benchmark
3333
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}
34+
import org.apache.spark.sql.comet._
35+
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
36+
import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, SparkPlan, WholeStageCodegenExec}
37+
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
3438
import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
3539
import org.apache.spark.sql.internal.SQLConf
3640
import org.apache.spark.sql.types.DecimalType
3741

3842
import org.apache.comet.CometConf
3943
import org.apache.comet.CometSparkSessionExtensions
4044

41-
trait CometBenchmarkBase extends SqlBasedBenchmark {
45+
trait CometBenchmarkBase extends SqlBasedBenchmark with AdaptiveSparkPlanHelper {
4246
override def getSparkSession: SparkSession = {
4347
val conf = new SparkConf()
4448
.setAppName("CometReadBenchmark")
@@ -155,9 +159,53 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
155159
}
156160
}
157161

162+
// Check that the plan is fully Comet native before running the benchmark
163+
withSQLConf(cometExecConfigs.toSeq: _*) {
164+
val df = spark.sql(query)
165+
val plan = stripAQEPlan(df.queryExecution.executedPlan)
166+
findFirstNonCometOperator(plan) match {
167+
case Some(op) =>
168+
// scalastyle:off println
169+
println()
170+
println("=" * 80)
171+
println("WARNING: Benchmark plan is NOT fully Comet native!")
172+
println(s"First non-Comet operator: ${op.nodeName}")
173+
println("=" * 80)
174+
println("Query plan:")
175+
println(plan.treeString)
176+
println("=" * 80)
177+
println()
178+
// scalastyle:on println
179+
case None =>
180+
// All operators are Comet native, no warning needed
181+
}
182+
}
183+
158184
benchmark.run()
159185
}
160186

187+
/**
 * Finds the first operator in the plan that is not accepted as part of a Comet native plan,
 * if any. This is used to verify that benchmarks are running fully on Comet native when
 * expected.
 *
 * Accepted operators are the Comet scans, Comet placeholder/wrapper nodes, Comet
 * columnar/row transitions, any other `CometExec`, Comet shuffle and broadcast exchanges,
 * and the Spark wrapper nodes `WholeStageCodegenExec`, `ColumnarToRowExec` and
 * `InputAdapter`.
 *
 * Based on CometTestBase.findFirstNonCometOperator.
 *
 * @param plan
 *   the physical plan to inspect
 * @return
 *   the first operator that matches none of the accepted cases, or `None` when every
 *   operator in the plan is accepted
 */
protected def findFirstNonCometOperator(plan: SparkPlan): Option[SparkPlan] = {
  // `find` stops at the first node for which the predicate is true, so there is no need for
  // a non-local `return` from inside a closure (which is implemented by throwing an
  // exception and is deprecated in Scala 3).
  plan.find {
    case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
        _: CometIcebergNativeScanExec =>
      false
    case _: CometSinkPlaceHolder | _: CometScanWrapper =>
      false
    case _: CometColumnarToRowExec | _: CometSparkToColumnarExec =>
      false
    case _: CometExec | _: CometShuffleExchangeExec | _: CometBroadcastExchangeExec =>
      false
    case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
      false
    case _ =>
      // Anything else is treated as a non-Comet operator.
      true
  }
}
208+
161209
protected def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
162210
val testDf = if (partition.isDefined) {
163211
df.write.partitionBy(partition.get)

0 commit comments

Comments
 (0)