@@ -31,14 +31,19 @@ import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS
3131import org .apache .spark .SparkConf
3232import org .apache .spark .benchmark .Benchmark
3333import org .apache .spark .sql .{DataFrame , DataFrameWriter , Row , SparkSession }
34+ import org .apache .spark .sql .comet .CometPlanChecker
35+ import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
3436import org .apache .spark .sql .execution .benchmark .SqlBasedBenchmark
3537import org .apache .spark .sql .internal .SQLConf
3638import org .apache .spark .sql .types .DecimalType
3739
3840import org .apache .comet .CometConf
3941import org .apache .comet .CometSparkSessionExtensions
4042
41- trait CometBenchmarkBase extends SqlBasedBenchmark {
43+ trait CometBenchmarkBase
44+ extends SqlBasedBenchmark
45+ with AdaptiveSparkPlanHelper
46+ with CometPlanChecker {
4247 override def getSparkSession : SparkSession = {
4348 val conf = new SparkConf ()
4449 .setAppName(" CometReadBenchmark" )
@@ -88,28 +93,6 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
8893 }
8994 }
9095
91- /** Runs function `f` with Comet on and off. */
92- final def runWithComet (name : String , cardinality : Long )(f : => Unit ): Unit = {
93- val benchmark = new Benchmark (name, cardinality, output = output)
94-
95- benchmark.addCase(s " $name - Spark " ) { _ =>
96- withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
97- f
98- }
99- }
100-
101- benchmark.addCase(s " $name - Comet " ) { _ =>
102- withSQLConf(
103- CometConf .COMET_ENABLED .key -> " true" ,
104- CometConf .COMET_EXEC_ENABLED .key -> " true" ,
105- SQLConf .ANSI_ENABLED .key -> " false" ) {
106- f
107- }
108- }
109-
110- benchmark.run()
111- }
112-
11396 /**
11497 * Runs an expression benchmark with standard cases: Spark, Comet (Scan), Comet (Scan + Exec).
11598 * This provides a consistent benchmark structure for expression evaluation.
@@ -149,6 +132,29 @@ trait CometBenchmarkBase extends SqlBasedBenchmark {
149132 CometConf .COMET_EXEC_ENABLED .key -> " true" ,
150133 " spark.sql.optimizer.constantFolding.enabled" -> " false" ) ++ extraCometConfigs
151134
135+ // Pre-flight check, outside the timed cases: run the query once under the Comet exec configs and print a warning if findFirstNonCometOperator reports an operator in the executed plan; the benchmark case below is still added either way
136+ withSQLConf(cometExecConfigs.toSeq: _* ) {
137+ val df = spark.sql(query)
138+ df.noop()
139+ val plan = stripAQEPlan(df.queryExecution.executedPlan)
140+ findFirstNonCometOperator(plan) match {
141+ case Some (op) =>
142+ // scalastyle:off println
143+ println()
144+ println(" =" * 80 )
145+ println(" WARNING: Benchmark plan is NOT fully Comet native!" )
146+ println(s " First non-Comet operator: ${op.nodeName}" )
147+ println(" =" * 80 )
148+ println(" Query plan:" )
149+ println(plan.treeString)
150+ println(" =" * 80 )
151+ println()
152+ // scalastyle:on println
153+ case None =>
154+ // No non-Comet operator was reported for this plan, so nothing is printed
155+ }
156+ }
157+
152158 benchmark.addCase(" Comet (Scan + Exec)" ) { _ =>
153159 withSQLConf(cometExecConfigs.toSeq: _* ) {
154160 spark.sql(query).noop()
0 commit comments