@@ -31,14 +31,18 @@ import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
3131import org .apache .spark .SparkConf
3232import org .apache .spark .benchmark .Benchmark
3333import org .apache .spark .sql .{DataFrame , DataFrameWriter , Row , SparkSession }
34+ import org .apache .spark .sql .comet ._
35+ import org .apache .spark .sql .comet .execution .shuffle .CometShuffleExchangeExec
36+ import org .apache .spark .sql .execution .{ColumnarToRowExec , InputAdapter , SparkPlan , WholeStageCodegenExec }
37+ import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
3438import org .apache .spark .sql .execution .benchmark .SqlBasedBenchmark
3539import org .apache .spark .sql .internal .SQLConf
3640import org .apache .spark .sql .types .DecimalType
3741
3842import org .apache .comet .CometConf
3943import org .apache .comet .CometSparkSessionExtensions
4044
41- trait CometBenchmarkBase extends SqlBasedBenchmark {
45+ trait CometBenchmarkBase extends SqlBasedBenchmark with AdaptiveSparkPlanHelper {
4246 override def getSparkSession : SparkSession = {
4347 val conf = new SparkConf ()
4448 .setAppName(" CometReadBenchmark" )
@@ -155,9 +159,53 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
155159 }
156160 }
157161
162+ // Check that the plan is fully Comet native before running the benchmark
163+ withSQLConf(cometExecConfigs.toSeq: _* ) {
164+ val df = spark.sql(query)
165+ val plan = stripAQEPlan(df.queryExecution.executedPlan)
166+ findFirstNonCometOperator(plan) match {
167+ case Some (op) =>
168+ // scalastyle:off println
169+ println()
170+ println(" =" * 80 )
171+ println(" WARNING: Benchmark plan is NOT fully Comet native!" )
172+ println(s " First non-Comet operator: ${op.nodeName}" )
173+ println(" =" * 80 )
174+ println(" Query plan:" )
175+ println(plan.treeString)
176+ println(" =" * 80 )
177+ println()
178+ // scalastyle:on println
179+ case None =>
180+ // All operators are Comet native, no warning needed
181+ }
182+ }
183+
158184 benchmark.run()
159185 }
160186
/**
 * Finds the first operator in the plan that is not recognised as a Comet operator, if any.
 * This is used to verify that benchmarks are running fully on Comet native when expected.
 *
 * Comet scans, Comet execs, Comet shuffle/broadcast exchanges and the Comet row/columnar
 * transitions are accepted, as are the `WholeStageCodegenExec`, `ColumnarToRowExec` and
 * `InputAdapter` nodes. Any other node counts as non-Comet.
 *
 * Based on CometTestBase.findFirstNonCometOperator.
 *
 * @param plan
 *   the physical plan to inspect
 * @return
 *   the first node, in the order `plan.find` visits them, that is not in the accepted set, or
 *   `None` when every node is accepted
 */
protected def findFirstNonCometOperator(plan: SparkPlan): Option[SparkPlan] = {
  // `find` stops at the first node for which the predicate is true, so no non-local `return`
  // (which is implemented by throwing a control exception from inside the lambda) is needed.
  plan.find {
    case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec |
        _: CometIcebergNativeScanExec =>
      false
    case _: CometSinkPlaceHolder | _: CometScanWrapper => false
    case _: CometColumnarToRowExec => false
    case _: CometSparkToColumnarExec => false
    case _: CometExec | _: CometShuffleExchangeExec => false
    case _: CometBroadcastExchangeExec => false
    case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter => false
    // Anything not listed above is treated as a non-Comet operator.
    case _ => true
  }
}
208+
161209 protected def prepareTable (dir : File , df : DataFrame , partition : Option [String ] = None ): Unit = {
162210 val testDf = if (partition.isDefined) {
163211 df.write.partitionBy(partition.get)
0 commit comments