diff --git a/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala b/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala
new file mode 100644
index 0000000000000..d01600bb439f1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/CheckErrorHelper.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ListBuffer
+import scala.jdk.CollectionConverters._
+
+import org.scalatest.Suite
+
+trait CheckErrorHelper { self: Suite =>
+
+  case class ExpectedContext(
+      contextType: QueryContextType,
+      objectType: String,
+      objectName: String,
+      startIndex: Int,
+      stopIndex: Int,
+      fragment: String,
+      callSitePattern: String
+  )
+
+  object ExpectedContext {
+    def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
+      ExpectedContext("", "", start, stop, fragment)
+    }
+
+    // Checks the fragment only: both indices are set to -1, which checkError treats as
+    // "do not compare". Intended for fragments that are unique within the query text.
+    def apply(fragment: String): ExpectedContext = {
+      ExpectedContext("", "", -1, -1, fragment)
+    }
+
+    def apply(
+        objectType: String,
+        objectName: String,
+        startIndex: Int,
+        stopIndex: Int,
+        fragment: String): ExpectedContext = {
+      new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex,
+        fragment, "")
+    }
+
+    def apply(fragment: String, callSitePattern: String): ExpectedContext = {
+      new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern)
+    }
+  }
+
+  /**
+   * Parameter keys that are omitted from comparison when absent from the expected map.
+   * For each error condition, the set lists keys that are removed from the actual
+   * exception parameters before comparison with the expected map.
+   * Test suites may override this to add or change ignorable parameters per condition.
+   */
+  protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map(
+    "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath")
+  )
+
+  /**
+   * Checks a [[SparkThrowable]] against expected results; all mismatches are reported together.
+   * @param exception The exception to check
+   * @param condition The expected error condition identifying the error
+   * @param sqlState The expected SQLSTATE; not verified if not supplied
+   * @param parameters Expected parameter names and values. The names are as defined
+   *  in the error-classes file.
+   * @param matchPVals Whether expected parameter values are regular expressions (default false)
+   * @param queryContext Expected query contexts, compared by position with the actual ones
+   */
+  protected def checkError(
+      exception: SparkThrowable,
+      condition: String,
+      sqlState: Option[String] = None,
+      parameters: Map[String, String] = Map.empty,
+      matchPVals: Boolean = false,
+      queryContext: Array[ExpectedContext] = Array.empty): Unit = {
+    val mismatches = new ListBuffer[String]
+
+    if (exception.getCondition != condition) {
+      mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'"
+    }
+    sqlState.foreach { state =>
+      if (exception.getSqlState != state) {
+        mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'"
+      }
+    }
+
+    val actualParameters = exception.getMessageParameters.asScala
+    val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String])
+    val actualParametersToCompare = actualParameters.filter { case (k, _) =>
+      !ignorable.contains(k) || parameters.contains(k)
+    }
+    if (matchPVals) {
+      if (actualParametersToCompare.size != parameters.size) {
+        mismatches += s"parameters size: expected ${parameters.size} but got" +
+          s" ${actualParametersToCompare.size}"
+      }
+      actualParametersToCompare.foreach { case (key, actualVal) =>
+        parameters.get(key) match {
+          case None =>
+            mismatches += s"parameters: unexpected key '$key' with value '$actualVal'"
+          case Some(pattern) if !actualVal.matches(pattern) =>
+            mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" +
+              s" '$pattern'"
+          case _ =>
+        }
+      }
+      parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key =>
+        mismatches += s"parameters: missing expected key '$key'"
+      }
+    } else if (actualParametersToCompare != parameters) {
+      mismatches += s"parameters: expected $parameters but got $actualParametersToCompare"
+    }
+
+    val actualQueryContext = exception.getQueryContext()
+    if (actualQueryContext.length != queryContext.length) {
+      mismatches += s"queryContext.length: expected ${queryContext.length}" +
+        s" but got ${actualQueryContext.length}"
+    }
+    actualQueryContext.zip(queryContext).zipWithIndex.foreach {
+      case ((actual, expected), idx) =>
+        if (actual.contextType() != expected.contextType) {
+          mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" +
+            s" but got ${actual.contextType()}"
+        }
+        if (actual.contextType() == QueryContextType.SQL) {
+          if (actual.objectType() != expected.objectType) {
+            mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" +
+              s" but got '${actual.objectType()}'"
+          }
+          if (actual.objectName() != expected.objectName) {
+            mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" +
+              s" but got '${actual.objectName()}'"
+          }
+          // An expected startIndex/stopIndex of -1 disables the check for that index, so that
+          // only the fragment is compared. This should be the case when the fragment is
+          // unique within the query text.
+          if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) {
+            mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" +
+              s" but got ${actual.startIndex()}"
+          }
+          if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) {
+            mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" +
+              s" but got ${actual.stopIndex()}"
+          }
+          if (actual.fragment() != expected.fragment) {
+            mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
+              s" but got '${actual.fragment()}'"
+          }
+        } else if (actual.contextType() == QueryContextType.DataFrame) {
+          if (actual.fragment() != expected.fragment) {
+            mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" +
+              s" but got '${actual.fragment()}'"
+          }
+          if (expected.callSitePattern.nonEmpty &&
+            !actual.callSite().matches(expected.callSitePattern)) {
+            mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" +
+              s" does not match pattern '${expected.callSitePattern}'"
+          }
+        }
+    }
+
+    if (mismatches.nonEmpty) {
+      val sb = new StringBuilder
+      sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n")
+      sb.append("=== Actual Exception State ===\n")
+      sb.append(s" condition: ${exception.getCondition}\n")
+      sb.append(s" sqlState: ${exception.getSqlState}\n")
+      sb.append(s" parameters:\n")
+      if (actualParameters.isEmpty) {
+        sb.append(" (empty)\n")
+      } else {
+        actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") }
+      }
+      actualQueryContext.zipWithIndex.foreach { case (ctx, idx) =>
+        sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n")
+        if (ctx.contextType() == QueryContextType.SQL) {
+          sb.append(s" objectType: ${ctx.objectType()}\n")
+          sb.append(s" objectName: ${ctx.objectName()}\n")
+          sb.append(s" startIndex: ${ctx.startIndex()}\n")
+          sb.append(s" stopIndex: ${ctx.stopIndex()}\n")
+          sb.append(s" fragment: ${ctx.fragment()}\n")
+        } else if (ctx.contextType() == QueryContextType.DataFrame) {
+          sb.append(s" fragment: ${ctx.fragment()}\n")
+          sb.append(s" callSite: ${ctx.callSite()}\n")
+        }
+      }
+      sb.append("\n=== Mismatches ===\n")
+      mismatches.foreach(m => sb.append(s" $m\n"))
+      fail(sb.toString())
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala
index 10504684be9fd..0fd595bf3fdf3 100644
--- a/core/src/test/scala/org/apache/spark/SparkTestSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkTestSuite.scala
@@ -22,8 +22,7 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Files, Path}
 import java.util.{Locale, TimeZone}
 
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.jdk.CollectionConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.logging.log4j._
 import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext}
@@ -70,6 +69,7 @@ trait SparkTestSuite
   with BeforeAndAfterEach
   with ThreadAudit
   with TimeLimits
+  with CheckErrorHelper
   with Logging {
   // scalastyle:on
 
@@ -274,150 +274,6 @@ trait 
SparkTestSuite } } - /** - * Parameter keys that are omitted from comparison when absent from the expected map. - * For each error condition, the set lists keys that are removed from the actual - * exception parameters before comparison with the expected map. - * Test suites may override this to add or change ignorable parameters per condition. - */ - protected def checkErrorIgnorableParameters: Map[String, Set[String]] = Map( - "TABLE_OR_VIEW_NOT_FOUND" -> Set("searchPath") - ) - - /** - * Checks an exception with an error condition against expected results. - * @param exception The exception to check - * @param condition The expected error condition identifying the error - * @param sqlState Optional the expected SQLSTATE, not verified if not supplied - * @param parameters A map of parameter names and values. The names are as defined - * in the error-classes file. - * @param matchPVals Optionally treat the parameters value as regular expression pattern. - * false if not supplied. - */ - protected def checkError( - exception: SparkThrowable, - condition: String, - sqlState: Option[String] = None, - parameters: Map[String, String] = Map.empty, - matchPVals: Boolean = false, - queryContext: Array[ExpectedContext] = Array.empty): Unit = { - val mismatches = new ListBuffer[String] - - if (exception.getCondition != condition) { - mismatches += s"condition: expected '$condition' but got '${exception.getCondition}'" - } - sqlState.foreach { state => - if (exception.getSqlState != state) { - mismatches += s"sqlState: expected '$state' but got '${exception.getSqlState}'" - } - } - - val actualParameters = exception.getMessageParameters.asScala - val ignorable = checkErrorIgnorableParameters.getOrElse(condition, Set.empty[String]) - val actualParametersToCompare = actualParameters.filter { case (k, _) => - !ignorable.contains(k) || parameters.contains(k) - } - if (matchPVals) { - if (actualParametersToCompare.size != parameters.size) { - mismatches += s"parameters size: 
expected ${parameters.size} but got" + - s" ${actualParametersToCompare.size}" - } - actualParametersToCompare.foreach { case (key, actualVal) => - parameters.get(key) match { - case None => - mismatches += s"parameters: unexpected key '$key' with value '$actualVal'" - case Some(pattern) if !actualVal.matches(pattern) => - mismatches += s"parameters['$key']: value '$actualVal' does not match pattern" + - s" '$pattern'" - case _ => - } - } - parameters.keys.filterNot(actualParametersToCompare.contains).foreach { key => - mismatches += s"parameters: missing expected key '$key'" - } - } else if (actualParametersToCompare != parameters) { - mismatches += s"parameters: expected $parameters but got $actualParametersToCompare" - } - - val actualQueryContext = exception.getQueryContext() - if (actualQueryContext.length != queryContext.length) { - mismatches += s"queryContext.length: expected ${queryContext.length}" + - s" but got ${actualQueryContext.length}" - } - actualQueryContext.zip(queryContext).zipWithIndex.foreach { - case ((actual, expected), idx) => - if (actual.contextType() != expected.contextType) { - mismatches += s"queryContext[$idx].contextType: expected ${expected.contextType}" + - s" but got ${actual.contextType()}" - } - if (actual.contextType() == QueryContextType.SQL) { - if (actual.objectType() != expected.objectType) { - mismatches += s"queryContext[$idx].objectType: expected '${expected.objectType}'" + - s" but got '${actual.objectType()}'" - } - if (actual.objectName() != expected.objectName) { - mismatches += s"queryContext[$idx].objectName: expected '${expected.objectName}'" + - s" but got '${actual.objectName()}'" - } - // If startIndex and stopIndex are -1, it means we simply want to check the - // fragment of the query context. This should be the case when the fragment is - // distinguished within the query text. 
- if (expected.startIndex != -1 && actual.startIndex() != expected.startIndex) { - mismatches += s"queryContext[$idx].startIndex: expected ${expected.startIndex}" + - s" but got ${actual.startIndex()}" - } - if (expected.stopIndex != -1 && actual.stopIndex() != expected.stopIndex) { - mismatches += s"queryContext[$idx].stopIndex: expected ${expected.stopIndex}" + - s" but got ${actual.stopIndex()}" - } - if (actual.fragment() != expected.fragment) { - mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + - s" but got '${actual.fragment()}'" - } - } else if (actual.contextType() == QueryContextType.DataFrame) { - if (actual.fragment() != expected.fragment) { - mismatches += s"queryContext[$idx].fragment: expected '${expected.fragment}'" + - s" but got '${actual.fragment()}'" - } - if (expected.callSitePattern.nonEmpty && - !actual.callSite().matches(expected.callSitePattern)) { - mismatches += s"queryContext[$idx].callSite: '${actual.callSite()}'" + - s" does not match pattern '${expected.callSitePattern}'" - } - } - } - - if (mismatches.nonEmpty) { - val sb = new StringBuilder - sb.append(s"checkError found ${mismatches.size} mismatch(es).\n\n") - sb.append("=== Actual Exception State ===\n") - sb.append(s" condition: ${exception.getCondition}\n") - sb.append(s" sqlState: ${exception.getSqlState}\n") - sb.append(s" parameters:\n") - if (actualParameters.isEmpty) { - sb.append(" (empty)\n") - } else { - actualParameters.foreach { case (k, v) => sb.append(s" $k -> $v\n") } - } - actualQueryContext.zipWithIndex.foreach { case (ctx, idx) => - sb.append(s" queryContext[$idx] (${ctx.contextType()}):\n") - if (ctx.contextType() == QueryContextType.SQL) { - sb.append(s" objectType: ${ctx.objectType()}\n") - sb.append(s" objectName: ${ctx.objectName()}\n") - sb.append(s" startIndex: ${ctx.startIndex()}\n") - sb.append(s" stopIndex: ${ctx.stopIndex()}\n") - sb.append(s" fragment: ${ctx.fragment()}\n") - } else if (ctx.contextType() == 
QueryContextType.DataFrame) { - sb.append(s" fragment: ${ctx.fragment()}\n") - sb.append(s" callSite: ${ctx.callSite()}\n") - } - } - sb.append("\n=== Mismatches ===\n") - mismatches.foreach(m => sb.append(s" $m\n")) - fail(sb.toString()) - } - } - protected def checkError( exception: SparkThrowable, condition: String, @@ -524,42 +380,6 @@ trait SparkTestSuite condition = "TABLE_OR_VIEW_ALREADY_EXISTS", parameters = Map("relationName" -> tableName)) - case class ExpectedContext( - contextType: QueryContextType, - objectType: String, - objectName: String, - startIndex: Int, - stopIndex: Int, - fragment: String, - callSitePattern: String - ) - - object ExpectedContext { - def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { - ExpectedContext("", "", start, stop, fragment) - } - - // Check the fragment only. This is only used when the fragment is distinguished within - // the query text - def apply(fragment: String): ExpectedContext = { - ExpectedContext("", "", -1, -1, fragment) - } - - def apply( - objectType: String, - objectName: String, - startIndex: Int, - stopIndex: Int, - fragment: String): ExpectedContext = { - new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, - fragment, "") - } - - def apply(fragment: String, callSitePattern: String): ExpectedContext = { - new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) - } - } - class LogAppender(msg: String = "", maxEvents: Int = 1000) extends AbstractAppender("logAppender", null, null, true, Property.EMPTY_ARRAY) { private val _loggingEvents = new ArrayBuffer[LogEvent]() diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index c27a83b79b89f..aad7f14b575de 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala +++ 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala
@@ -241,6 +241,27 @@ class Dataset[T] private[sql] (
     // scalastyle:on println
   }
 
+  /** Returns the explain string that `sparkSession.analyze` reports for `mode`. */
+  private[connect] def explainString(mode: String): String = {
+    import proto.AnalyzePlanRequest.Explain.ExplainMode
+    // The mode name is trimmed and lower-cased; unknown names fail before analyze is called.
+    val explainMode = mode.trim.toLowerCase(util.Locale.ROOT) match {
+      case "simple" => ExplainMode.EXPLAIN_MODE_SIMPLE
+      case "extended" => ExplainMode.EXPLAIN_MODE_EXTENDED
+      case "codegen" => ExplainMode.EXPLAIN_MODE_CODEGEN
+      case "cost" => ExplainMode.EXPLAIN_MODE_COST
+      case "formatted" => ExplainMode.EXPLAIN_MODE_FORMATTED
+      case _ => throw new IllegalArgumentException(s"Unsupported explain mode: $mode")
+    }
+    val response =
+      sparkSession.analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN, Some(explainMode))
+    response.getExplain.getExplainString
+  }
+
+  /** Boolean shorthand: `true` selects the "extended" mode, `false` the "simple" mode. */
+  private[connect] def explainString(extended: Boolean): String =
+    explainString(if (extended) "extended" else "simple")
+
   /** @inheritdoc */
   def isLocal: Boolean = sparkSession
     .analyze(plan, proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticSuite.scala
new file mode 100644
index 0000000000000..9cd9b4677920f
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleSessionAgnosticSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog
+
+/**
+ * Example of the session-agnostic test pattern, demonstrating shared, classic-only, and
+ * connect-only tests colocated in one file.
+ *
+ * {{{
+ *   // Shared tests run against BOTH a classic and a connect session.
+ *   class FooSuite extends sql.SessionQueryTest {
+ *     test("shared") { checkAnswer(spark.sql("SELECT 1"), Row(1)) }
+ *   }
+ *
+ *   // Classic-only tests. Extend classic.SessionQueryTest (not FooSuite, which already runs
+ *   // the shared tests on classic) so that classic-only APIs are visible.
+ *   class FooClassicSuite extends classic.SessionQueryTest {
+ *     test("classic only") { spark.sql("SELECT 1").queryExecution }
+ *   }
+ *
+ *   // Connect variant: re-runs FooSuite's shared tests on connect, plus connect-only tests.
+ *   class FooConnectSuite extends FooSuite with connect.SessionQueryTest {
+ *     test("connect only") { assert(spark.range(1).count() == 1) }
+ *   }
+ * }}}
+ *
+ * In real suites the shared and classic-only parts are written in sql/core and only the connect
+ * variant lives in sql/connect; they are colocated here so the whole pattern is on one page.
+ */
+class ExampleSessionAgnosticSuite extends sql.SessionQueryTest {
+
+  override protected def sparkConf: SparkConf =
+    super.sparkConf
+      .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName)
+      .set("spark.sql.defaultCatalog", "testcat")
+
+  test("Example classic/connect-agnostic testcase") {
+    withTable("t") {
+      spark.sql("CREATE TABLE t (id INT, salary INT) USING foo").collect()
+      spark.sql("INSERT INTO t VALUES (1, 100)").collect()
+
+      val df1 = spark.table("t")
+
+      spark.sql("ALTER TABLE t ADD COLUMN new_column INT").collect()
+      spark.sql("INSERT INTO t VALUES (2, 200, -1)").collect()
+
+      val df2 = spark.table("t")
+      val selfJoin = df1.join(df2, df1("id") === df2("id"))
+
+      // Behaviour that differs between classic and connect can be documented via `isConnect`.
+      if (isConnect) {
+        // Connect re-resolves df1 with the new 3-column schema (id, salary, new_column).
+        assert(
+          selfJoin.columns.length == 6,
+          s"Expected 6 columns (3 + 3) but got: ${selfJoin.columns.mkString(", ")}")
+        checkAnswer(selfJoin, Seq(Row(1, 100, null, 1, 100, null), Row(2, 200, -1, 2, 200, -1)))
+      } else {
+        // Classic: df1 keeps its original 2-column schema (id, salary).
+        assert(
+          selfJoin.columns.length == 5,
+          s"Expected 5 columns (2 + 3) but got: ${selfJoin.columns.mkString(", ")}")
+        checkAnswer(selfJoin, Seq(Row(1, 100, 1, 100, null), Row(2, 200, 2, 200, -1)))
+      }
+    }
+  }
+
+  test("testcase that uses withConf") {
+    // Since SQLConf is not part of the public API,
+    // `withConf` can be used to temporarily change the RuntimeConfig.
+    withConf("spark.sql.charAsVarchar" -> "true") {
+      withTable("t") {
+        spark.sql("CREATE TABLE t(col CHAR(5)) USING foo")
+        checkAnswer(spark.sql("desc t").selectExpr("data_type"), Seq(Row("varchar(5)")))
+      }
+    }
+  }
+}
+
+/**
+ * Example of a classic-only suite. Extends [[classic.SessionQueryTest]], so `spark` is a
+ * [[classic.SparkSession]] and classic-only APIs are visible. The shared tests already run on
+ * classic via [[ExampleSessionAgnosticSuite]], so they are not repeated here.
+ */
+class ExampleClassicSuite extends classic.SessionQueryTest {
+
+  test("classic-only testcase") {
+    // `spark` is a classic.SparkSession, so classic-only APIs like queryExecution are visible.
+    val df = spark.sql("SELECT 1")
+    assert(df.queryExecution.analyzed.output.length == 1)
+  }
+}
+
+/**
+ * Connect variant of [[ExampleSessionAgnosticSuite]]: re-runs its shared tests against a connect
+ * session and adds connect-only tests.
+ */
+class ExampleSessionAgnosticConnectSuite
+  extends ExampleSessionAgnosticSuite
+  with SessionQueryTest {
+
+  test("connect-only testcase") {
+    // Tests declared in the connect variant run only on connect (not classic).
+    // Here `spark` is a connect SparkSession.
+    assert(isConnect)
+    assert(spark.range(1).count() == 1)
+  }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala
new file mode 100644
index 0000000000000..2748af9ae9b1e
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql.QueryTestSuite
+
+/**
+ * Runs the inherited [[QueryTestSuite]] tests through a Connect session.
+ *
+ * Mixing in this package's [[SessionQueryTest]] validates the
+ * `FooSuite with connect.SessionQueryTest` pattern: the tests are inherited unchanged, but
+ * execute against a [[SparkSession connect.SparkSession]] instead of a classic one.
+ */
+class QueryTestWithConnectSuite extends QueryTestSuite with SessionQueryTest
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala
new file mode 100644
index 0000000000000..89e7ea90f8f49
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */
+package org.apache.spark.sql.connect
+
+import scala.util.matching.Regex
+
+import org.apache.spark.sql
+
+/**
+ * Overrides test utils to implement 'connect variants' of suites declared in sql/core:
+ * {{{
+ *   // in sql/core
+ *   FooSuite extends SessionQueryTest { test("") { ... } }
+ *
+ *   // in sql/connect
+ *   FooConnectSuite extends FooSuite with connect.SessionQueryTest
+ * }}}
+ *
+ * This trait overrides [[spark]] to use a [[SparkSession connect.SparkSession]], which executes
+ * via the gRPC API using an in-process connect server.
+ */
+trait SessionQueryTest extends sql.SessionQueryTest with SparkSessionBinder {
+
+  // Matches the operator name `Sort` (or `PhotonSort`) as a whole word.
+  private val sortNodePattern: Regex = """\b(?:Photon)?Sort\b""".r
+
+  /** Approximates [[sql.SessionQueryTest.isDfSorted]] by inspecting the explain string. */
+  override def isDfSorted(df: sql.DataFrame): Boolean = df match {
+    case connectDf: DataFrame =>
+      sortNodePattern.findFirstIn(connectDf.explainString(extended = false)).isDefined
+    case other => super.isDfSorted(other)
+  }
+
+  override def isConnect: Boolean = true
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala
new file mode 100644
index 0000000000000..c43c3019298ac
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import java.util.UUID
+
+import org.apache.spark.{SparkEnv, SparkFunSuite}
+import org.apache.spark.sql
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+/**
+ * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server. Extends
+ * [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a
+ * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client
+ * session on top by starting the gRPC service in-process.
+ */
+trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite =>
+
+  private var _connectSpark: SparkSession = _
+
+  protected override def spark: SparkSession = _connectSpark
+
+  /** The underlying classic session used by the in-process server. */
+  private def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession]
+
+  /** Starts the in-process Connect server on a free port, then connects `_connectSpark`. */
+  override protected def beforeAll(): Unit = {
+    super.beforeAll()
+    // Other suites using mocks leave a mess in the global executionManager,
+    // shut it down so that it's cleared before starting server.
+    SparkConnectService.executionManager.shutdown()
+    val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT)
+    try {
+      // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port.
+      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0)
+      SparkConnectService.start(classicSpark.sparkContext)
+    } finally {
+      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort)
+    }
+    val client = SparkConnectClient
+      .builder()
+      .port(SparkConnectService.localPort)
+      .sessionId(UUID.randomUUID().toString)
+      .userId("test")
+      .build()
+    _connectSpark = SparkSession.builder().client(client).create()
+  }
+
+  /** Closes the client session, then always stops the server and runs `super.afterAll()`. */
+  override def afterAll(): Unit = {
+    try {
+      if (_connectSpark != null) _connectSpark.close()
+    } finally {
+      _connectSpark = null
+      // Nested so that a failing stop() still lets super.afterAll() run its own cleanup.
+      try SparkConnectService.stop() finally super.afterAll()
+    }
+  }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala
new file mode 100644
index 0000000000000..959e7fd899397
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql
+
+/**
+ * A common trait for test suites or utils that require a connect [[SparkSession]]. Use together
+ * with e.g. 
[[SparkSessionBinder]]. + */ +trait SparkSessionProvider extends sql.SparkSessionProvider { + protected override def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala new file mode 100644 index 0000000000000..42b71a49b1f86 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CheckAnswerHelper.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.util.SparkErrorUtils + +/** + * Provides [[checkAnswer]] helper for SQL- & DataFrame-API tests. + * + * TODO: should be moved to sql/api together with SessionQueryTestBase + */ +@Experimental +trait CheckAnswerHelper extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. 
+ */ + protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + + val analyzedDF = try df catch { + case ae: ExtendedAnalysisException => + if (ae.plan.isDefined) { + fail( + s""" + |Failed to analyze query: $ae + |${ae.plan.get} + | + |${SparkErrorUtils.stackTraceToString(ae)} + |""".stripMargin) + } else { + throw ae + } + } + + getErrorMessageInCheckAnswer(analyzedDF, expectedAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /* + * Note: when moving this to sql/api, implementation should stay in sql/core + * (i.e. only have abstract decl in sql/api) + */ + protected def isDfSorted(df: DataFrame): Boolean = { + df match { + case df: classic.DataFrame => + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + case _ => + // isDfSorted should be overridden by connect so that this case can't be reached. + throw new RuntimeException( + s"""Cannot determine whether df is sorted: $df. + |Maybe the suite is missing the connect.SessionQueryTest mixin?""".stripMargin) + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a None will + * be returned. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. 
+ */ + private def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row]): Option[String] = { + val sparkAnswer = try df.collect().toSeq catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${if (df.isInstanceOf[classic.DataFrame]) { df.queryExecution } else df.toString} + |== Exception == + |$e + |${SparkErrorUtils.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + RowComparisonUtils.sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${if (df.isInstanceOf[classic.DataFrame]) { df.queryExecution } else df.toString } + |== Results == + |$results + """.stripMargin + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryCleanupHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryCleanupHelper.scala new file mode 100644 index 0000000000000..4d84036e8e8c2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryCleanupHelper.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.scalatest.Assertions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.util.SparkErrorUtils + +/** + * Provides [[withTable]], [[withView]], and [[withUserDefinedFunction]] + */ +@Experimental +trait QueryCleanupHelper extends SparkSessionProvider with Assertions { + + /** + * Drops tables `tableNames` after calling `f`. + */ + protected def withTable(tableNames: String*)(f: => Unit): Unit = { + SparkErrorUtils.tryWithSafeFinally(f) { + tableNames.foreach { name => + spark.sql(s"DROP TABLE IF EXISTS $name") + } + } + } + + /** + * Drops views `viewNames` after calling `f`. + */ + protected def withView(viewNames: String*)(f: => Unit): Unit = { + SparkErrorUtils.tryWithSafeFinally(f)( + viewNames.foreach { name => + spark.sql(s"DROP VIEW IF EXISTS $name") + } + ) + } + + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } finally { + functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !spark.catalog.functionExists(functionName), + s"Function $functionName should have been dropped. 
But, it still exists.") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5a1ea3d9f53cf..02cb61787abc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -30,12 +30,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.scalactic.source.Position -import org.scalatest.{Assertions, BeforeAndAfterAll, Suite, Tag} +import org.scalatest.{BeforeAndAfterAll, Suite, Tag} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ExtendedAnalysisException -import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedAttribute} import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans._ @@ -59,6 +58,8 @@ trait QueryTestBase extends Eventually with BeforeAndAfterAll with SQLTestData + with CheckAnswerHelper + with QueryCleanupHelper with PlanTestBase { self: Suite => /** @@ -156,7 +157,7 @@ trait QueryTestBase * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. 
*/ - protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { val analyzedDF = try df catch { case ae: ExtendedAnalysisException => if (ae.plan.isDefined) { @@ -172,9 +173,15 @@ trait QueryTestBase } } - assertEmptyMissingInput(analyzedDF) + if (analyzedDF.isInstanceOf[classic.DataFrame]) { + assertEmptyMissingInput(analyzedDF) - QueryTest.checkAnswer(analyzedDF, expectedAnswer) + SQLExecution.withSQLConfPropagated(analyzedDF.sparkSession) { + analyzedDF.materializedRdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + super.checkAnswer(analyzedDF, expectedAnswer) } protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = { @@ -322,25 +329,6 @@ trait QueryTestBase } } - /** - * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). - */ - protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { - try { - f - } catch { - case cause: Throwable => throw cause - } finally { - functions.foreach { case (functionName, isTemporary) => - val withTemporary = if (isTemporary) "TEMPORARY" else "" - spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") - assert( - !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), - s"Function $functionName should have been dropped. But, it still exists.") - } - } - } - /** * Drops temporary view `viewNames` after calling `f`. */ @@ -367,28 +355,6 @@ trait QueryTestBase } } - /** - * Drops table `tableName` after calling `f`. - */ - protected def withTable(tableNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f) { - tableNames.foreach { name => - spark.sql(s"DROP TABLE IF EXISTS $name") - } - } - } - - /** - * Drops view `viewName` after calling `f`. 
- */ - protected def withView(viewNames: String*)(f: => Unit): Unit = { - Utils.tryWithSafeFinally(f)( - viewNames.foreach { name => - spark.sql(s"DROP VIEW IF EXISTS $name") - } - ) - } - /** * Drops cache `cacheName` after calling `f`. */ @@ -823,7 +789,17 @@ trait QueryTest extends SparkFunSuite with QueryTestBase { } } -object QueryTest extends Assertions { +object QueryTest extends CheckAnswerHelper { + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df the DataFrame to be executed + * @param expectedAnswer the expected result in a Seq of Rows. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = { + checkAnswer(df, expectedAnswer, checkToRDD = true) + } + /** * Runs the plan and makes sure the answer matches the expected result. * @@ -831,13 +807,26 @@ object QueryTest extends Assertions { * @param expectedAnswer the expected result in a Seq of Rows. * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. */ - def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = { - getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match { + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean): Unit = { + if (checkToRDD) { + SQLExecution.withSQLConfPropagated(df.sparkSession) { + df.materializedRdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + super.checkAnswer(df, expectedAnswer) + } + + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { case Some(errorMessage) => fail(errorMessage) case None => } } + override protected def isDfSorted(df: DataFrame): Boolean = + df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty + /** * Runs the plan and makes sure the answer matches the expected result. 
* If there was exception during the execution or the contents of the DataFrame does not @@ -886,111 +875,27 @@ object QueryTest extends Assertions { } - def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { - // Converts data to types that we can do equality comparison using Scala collections. - // For BigDecimal type, the Scala type has a better definition of equality test (similar to - // Java's java.math.BigDecimal.compareTo). - // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for - // equality test. - val converted: Seq[Row] = answer.map(prepareRow) - if (!isSorted) converted.sortBy(_.toString()) else converted - } - - // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { - Row.fromSeq(row.toSeq.map { - case null => null - case bd: java.math.BigDecimal => BigDecimal(bd) - // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ - case seq: Seq[_] => seq.map { - case b: java.lang.Byte => b.byteValue - case s: java.lang.Short => s.shortValue - case i: java.lang.Integer => i.intValue - case l: java.lang.Long => l.longValue - case f: java.lang.Float => f.floatValue - case d: java.lang.Double => d.doubleValue - case x => x - } - // Convert array to Seq for easy equality check. 
- case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - // SPARK-51349: "null" and null had the same precedence in sorting - case "null" => "__null_string__" - case o => o - }) - } + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = + RowComparisonUtils.prepareAnswer(answer, isSorted) - private def genError( - expectedAnswer: Seq[Row], - sparkAnswer: Seq[Row], - isSorted: Boolean = false): String = { - val getRowType: Option[Row] => String = row => - row.map(row => - if (row.schema == null) { - "struct<>" - } else { - s"${row.schema.catalogString}" - }).getOrElse("struct<>") - - s""" - |== Results == - |${ - sideBySide( - s"== Correct Answer - ${expectedAnswer.size} ==" +: - getRowType(expectedAnswer.headOption) +: - prepareAnswer(expectedAnswer, isSorted).map(_.toString()), - s"== Spark Answer - ${sparkAnswer.size} ==" +: - getRowType(sparkAnswer.headOption) +: - prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") - } - """.stripMargin - } + def prepareRow(row: Row): Row = RowComparisonUtils.prepareRow(row) def includesRows( expectedRows: Seq[Row], sparkAnswer: Seq[Row]): Option[String] = { if (!prepareAnswer(expectedRows, true).toSet.subsetOf(prepareAnswer(sparkAnswer, true).toSet)) { - return Some(genError(expectedRows, sparkAnswer, true)) + return Some(RowComparisonUtils.genError(expectedRows, sparkAnswer, true)) } None } - def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { - case (null, null) => true - case (null, _) => false - case (_, null) => false - case (a: Array[_], b: Array[_]) => - a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Map[_, _], b: Map[_, _]) => - a.size == b.size && a.keys.forall { aKey => - b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) - } - case (a: Iterable[_], b: Iterable[_]) => - a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Product, b: Product) => - 
compare(a.productIterator.toSeq, b.productIterator.toSeq) - case (a: Row, b: Row) => - compare(a.toSeq, b.toSeq) - // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. - // in some hardware NaN can be represented with different bits, so first check for it - case (a: Double, b: Double) => - a.isNaN && b.isNaN || - java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) - case (a: Float, b: Float) => - a.isNaN && b.isNaN || - java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) - case (a, b) => a == b - } + def compare(obj1: Any, obj2: Any): Boolean = RowComparisonUtils.compare(obj1, obj2) def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], - isSorted: Boolean = false): Option[String] = { - if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { - return Some(genError(expectedAnswer, sparkAnswer, isSorted)) - } - None - } + isSorted: Boolean = false): Option[String] = + RowComparisonUtils.sameRows(expectedAnswer, sparkAnswer, isSorted) def compareAnswers( sparkAnswer: Seq[Row], @@ -1054,13 +959,6 @@ object QueryTest extends Assertions { } } - def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { - getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { - case Some(errorMessage) => fail(errorMessage) - case None => - } - } - def withQueryExecutionsCaptured(spark: SparkSession)(thunk: => Unit): Seq[QueryExecution] = { var capturedQueryExecutions = Seq.empty[QueryExecution] @@ -1211,7 +1109,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends test.SharedSparkSession { +class QueryTestSuite extends QueryTest with SparkSessionBinder { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil) @@ -1223,4 +1121,19 @@ class QueryTestSuite extends 
test.SharedSparkSession { "from range(2)"), Seq(Row(Row(null)), Row(Row("null")))) } + + test("checkAnswer demands correct result order for ordered queries") { + val e = intercept[org.scalatest.exceptions.TestFailedException] { + checkAnswer( + sql("SELECT col1 FROM VALUES 1, 2, 1, 3 ORDER BY col1"), + Seq(Row(3), Row(1), Row(1), Row(2))) + } + assert(e.getMessage().contains("Results do not match for query")) + } + + test("checkAnswer ignores result order for unordered queries") { + checkAnswer( + sql("SELECT col1 FROM VALUES 1, 2, 1, 3"), + Seq(Row(3), Row(1), Row(1), Row(2))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowComparisonUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowComparisonUtils.scala new file mode 100644 index 0000000000000..3b65af522a753 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowComparisonUtils.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.util.SparkStringUtils + +/** + * Pure comparison helpers shared by [[CheckAnswerHelper]] and [[QueryTest]]. 
+ */ +private[sql] object RowComparisonUtils { + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). + // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. 
+ case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + // SPARK-51349: "null" and null had the same precedence in sorting + case "null" => "__null_string__" + case o => o + }) + } + + def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + + s""" + |== Results == + |${ + SparkStringUtils.sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") + } + """.stripMargin + } + + def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Map[_, _], b: Map[_, _]) => + a.size == b.size && a.keys.forall { aKey => + b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) + } + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + case (a: Row, b: Row) => + compare(a.toSeq, b.toSeq) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
+ // in some hardware NaN can be represented with different bits, so first check for it + case (a: Double, b: Double) => + a.isNaN && b.isNaN || + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + a.isNaN && b.isNaN || + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a, b) => a == b + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { + return Some(genError(expectedAnswer, sparkAnswer, isSorted)) + } + None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala new file mode 100644 index 0000000000000..472533d13c5a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTest.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite + +/** + * Provides connect-compatible test utils to write suites that have 'connect variants': + * {{{ + * // in sql/core + * FooSuite extends SessionQueryTest { test("") { ... } } + * + * // in sql/connect + * FooConnectSuite extends connect.SessionQueryTest + * }}} + * + * While this trait internally uses a [[classic.SparkSession]] when executing tests, + * it is exposed as a [[SparkSession sql.SparkSession]] to allow for overriding on the connect side. + * + * For classic-specific tests, use [[classic.SessionQueryTest]]. + * + * For example usage, see the `ExampleSessionAgnosticSuite` example suites in sql/connect. + */ +trait SessionQueryTest + extends SparkFunSuite + with SessionQueryTestBase + with SparkSessionBinder { + + override def isConnect: Boolean = false +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTestBase.scala new file mode 100644 index 0000000000000..7cb77d414ec49 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionQueryTestBase.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +// scalastyle:off funsuite +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.CheckErrorHelper +// scalastyle:on + +/** + * TODO should be moved to sql/api + * + * base for fully sql/core independent tests, i.e. this trait could be moved to sql/api and then + * used in sql/connect/client. + */ +trait SessionQueryTestBase + extends AnyFunSuite + with SparkSessionProvider + with CheckAnswerHelper + with CheckErrorHelper + with QueryCleanupHelper { + + /** + * Sets all configurations specified in `pairs`, calls `f`, and then restores all configurations. + * + * Use this instead of `withSQLConf` as [[internal.SQLConf SQLConf]] is not part of Spark's public + * API. + */ + protected def withConf[T](pairs: (String, String)*)(f: => T): T = { + val (keys, values) = pairs.unzip + val currentValues = keys.map { key => + if (spark.conf.contains(key)) { + Some(spark.conf.get(key)) + } else { + None + } + } + keys.lazyZip(values).foreach { (k, v) => + spark.conf.set(k, v) + } + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) + } + } + } + + /** + * Whether the bound session is a Spark Connect session (`false` for classic), so that tests can + * handle and document session-specific behaviour. + * + * {{{ + * test(...) { + * val df = // query with connect-specific behaviour + * if (isConnect) { + * checkError(...) + * } else { + * checkAnswer(df, ...) 
+ * } + * } + * }}} + */ + def isConnect: Boolean +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala new file mode 100644 index 0000000000000..017e35ff4ce4b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf, SparkFunSuite} +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.TestSparkSession + +/** + * Provides a [[spark]] implementation by creating a [[classic.SparkSession]]. 
+ * + * Counterpart to [[SparkSessionProvider]], used in [[org.apache.spark.sql.test.SharedSparkSession]] + */ +trait SparkSessionBinder extends SparkSessionBinderBase { self: SparkFunSuite => + + /** + * Suites extending this trait are sharing resources (e.g. SparkSession) in their + * tests. This trait initializes the spark session in its [[beforeAll()]] implementation before + * the automatic thread snapshot is performed, so the audit code could fail to report threads + * leaked by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + doThreadPostAudit() + } + } +} + +/** + * [[SparkSessionBinderBase]] is needed for now as + * [[test.SharedSparkSessionBase SharedSparkSessionBase]] is still used by e.g. + * [[test.GenericWordSpecSuite]]. + * + * This Base might be merged into [[SparkSessionBinder]] once it is not required anymore. + * + * TODO: migrate SharedSparkSessionBase users so this can be removed + */ +@Experimental +trait SparkSessionBinderBase + extends SparkSessionProvider + with BeforeAndAfterEach + with BeforeAndAfterAll + with Eventually { self: Suite => + + protected def sparkConf: SparkConf = { + val conf = new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. 
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) + conf.set( + StaticSQLConf.WAREHOUSE_PATH, + conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) + conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) + conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: classic.SparkSession = null + + protected override def spark: SparkSession = _spark + + /** + * The [[SQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: classic.SparkSession = { + classic.SparkSession.cleanupAnyExistingSession() + new TestSparkSession(sparkConf) + } + + protected def sqlConf: SQLConf = _spark.sessionState.conf + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. 
+ */ + protected def initializeSession(): Unit = { + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + _spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds), interval(2.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala new file mode 100644 index 0000000000000..36a52e16314cb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.FilterExec + +/** + * Extends [[org.apache.spark.sql.QueryTest sql.QueryTest]] to provide classic-only helpers. + */ +trait QueryTest extends sql.QueryTest with SparkSessionProvider { + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.executedPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. + */ + protected implicit override def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala new file mode 100644 index 0000000000000..d56146b05d23b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SessionQueryTest.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +/** + * Override of [[sql.SessionQueryTest]] that provides [[SparkSession classic.SparkSession]]. + * + * Can be used to declare classic-specific tests: + * {{{ + * class FooSuite extends sql.SessionQueryTest { + * // shared classic/connect-agnostic testcases + * } + * + * // no need to extend FooSuite as sql.SessionQueryTest + * // already executes shared tests via classic internally. + * class FooClassicSuite extends classic.SessionQueryTest { + * test("classic-only test") { + * // classic-only APIs are visible here + * spark.sessionState.conf + * } + * } + * }}} + */ +trait SessionQueryTest extends sql.SessionQueryTest with SparkSessionBinder + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala new file mode 100644 index 0000000000000..2f79876d841d8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.{sql, SparkFunSuite} + +/** + * Overrides [[spark]] to provide a [[SparkSession classic.SparkSession]] + */ +trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite => + override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala new file mode 100644 index 0000000000000..77de0db4bf68b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +trait SparkSessionProvider extends sql.SparkSessionProvider { + override protected def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b20b6d397fd17..5c2bc6829ea59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.classic import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT._ @@ -37,14 +38,15 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions.struct import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -abstract class ParquetQuerySuite extends ParquetTest with SharedSparkSession { +abstract class ParquetQuerySuite extends ParquetTest + with QueryTest + with classic.SparkSessionBinder { import testImplicits._ test("simple select queries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index fb26d3311ebef..47e143339f406 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -19,40 +19,12 @@ package org.apache.spark.sql.test import scala.concurrent.duration._ -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually +import org.scalatest.Suite -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK -import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.{QueryTest, QueryTestBase, SparkSessionBinderBase} +import org.apache.spark.sql.classic -trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { - - /** - * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their - * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before - * the automatic thread snapshot is performed, so the audit code could fail to report threads - * leaked by that shared session. - * - * The behavior is overridden here to take the snapshot before the spark session is initialized. 
- */ - override protected val enableAutoThreadAudit = false - - protected override def beforeAll(): Unit = { - doThreadPreAudit() - super.beforeAll() - } - - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - doThreadPostAudit() - } - } +trait SharedSparkSession extends QueryTest with classic.SparkSessionBinder { // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. @@ -82,124 +54,12 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { } } + /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSparkSessionBase - extends QueryTestBase - with SparkSessionProvider - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - val conf = new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) - // Disable ConvertToLocalRelation for better test coverage. Test cases built on - // LocalRelation will exercise the optimization rules better by disabling it as - // this rule may potentially block testing of other optimization rules such as - // ConstantPropagation etc. 
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) - conf.set( - StaticSQLConf.WAREHOUSE_PATH, - conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) - conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) - conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected override def spark: classic.SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - classic.SparkSession.cleanupAnyExistingSession() - new TestSparkSession(sparkConf) - } +trait SharedSparkSessionBase extends QueryTestBase with SparkSessionBinderBase { self: Suite => - protected def sqlConf: SQLConf = _spark.sessionState.conf - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. 
It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. - */ - protected def initializeSession(): Unit = { - if (_spark == null) { - _spark = createSparkSession - } - } - - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - try { - if (_spark != null) { - try { - _spark.sessionState.catalog.reset() - } finally { - _spark.stop() - _spark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds), interval(2.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } + protected override def spark: classic.SparkSession = + super.spark.asInstanceOf[classic.SparkSession] } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 47cc9853f754d..172d374385474 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.hive.test import 
org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSessionProvider -import org.apache.spark.sql.classic.SparkSession +import org.apache.spark.sql.classic.{SparkSession, SparkSessionProvider} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.client.HiveClient