Skip to content

Commit 6e81376

Browse files
author
Matthis Gördel
committed
WIP
1 parent 0b427b4 commit 6e81376

18 files changed

Lines changed: 475 additions & 100 deletions
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import java.util.TimeZone
21+
22+
import scala.jdk.CollectionConverters._
23+
import scala.language.implicitConversions
24+
25+
import org.scalatest.Assertions
26+
27+
import org.apache.spark.util.{SparkErrorUtils, SparkStringUtils}
28+
import org.apache.spark.util.ArrayImplicits._
29+
30+
trait CheckAnswerHelper extends Assertions {
31+
32+
/**
33+
* Runs the plan and makes sure the answer matches the expected result.
34+
*
35+
* @param df the DataFrame to be executed
36+
* @param expectedAnswer the expected result in a Seq of Rows.
37+
*/
38+
protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
39+
getErrorMessageInCheckAnswer(df, expectedAnswer) match {
40+
case Some(errorMessage) => fail(errorMessage)
41+
case None =>
42+
}
43+
}
44+
45+
protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
46+
checkAnswer(df, Seq(expectedAnswer))
47+
}
48+
49+
protected def checkAnswer(df: => DataFrame, expectedAnswer: DataFrame): Unit = {
50+
checkAnswer(df, expectedAnswer.collect().toImmutableArraySeq)
51+
}
52+
53+
protected def checkAnswer(df: => DataFrame, expectedAnswer: Array[Row]): Unit = {
54+
checkAnswer(df, expectedAnswer.toImmutableArraySeq)
55+
}
56+
57+
protected def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = {
58+
checkAnswer(df, expectedAnswer.asScala.toSeq)
59+
}
60+
61+
protected def isDfSorted(df: DataFrame): Boolean
62+
63+
/**
64+
* Runs the plan and makes sure the answer matches the expected result.
65+
* If there was exception during the execution or the contents of the DataFrame does not
66+
* match the expected result, an error message will be returned. Otherwise, a None will
67+
* be returned.
68+
*
69+
* @param df the DataFrame to be executed
70+
* @param expectedAnswer the expected result in a Seq of Rows.
71+
*/
72+
private def getErrorMessageInCheckAnswer(
73+
df: DataFrame,
74+
expectedAnswer: Seq[Row]): Option[String] = {
75+
val sparkAnswer = try df.collect().toSeq catch {
76+
case e: Exception =>
77+
val errorMessage =
78+
s"""
79+
|Exception thrown while executing query:
80+
|${df.queryExecution}
81+
|== Exception ==
82+
|$e
83+
|${SparkErrorUtils.stackTraceToString(e)}
84+
""".stripMargin
85+
return Some(errorMessage)
86+
}
87+
88+
sameRows(expectedAnswer, sparkAnswer, isDfSorted(df)).map { results =>
89+
s"""
90+
|Results do not match for query:
91+
|Timezone: ${TimeZone.getDefault}
92+
|Timezone Env: ${sys.env.getOrElse("TZ", "")}
93+
|
94+
|${df.queryExecution}
95+
|== Results ==
96+
|$results
97+
""".stripMargin
98+
}
99+
}
100+
101+
private def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = {
102+
// Converts data to types that we can do equality comparison using Scala collections.
103+
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
104+
// Java's java.math.BigDecimal.compareTo).
105+
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
106+
// equality test.
107+
val converted: Seq[Row] = answer.map(prepareRow)
108+
if (!isSorted) converted.sortBy(_.toString()) else converted
109+
}
110+
111+
// We need to call prepareRow recursively to handle schemas with struct types.
112+
private def prepareRow(row: Row): Row = {
113+
Row.fromSeq(row.toSeq.map {
114+
case null => null
115+
case bd: java.math.BigDecimal => BigDecimal(bd)
116+
// Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
117+
case seq: Seq[_] => seq.map {
118+
case b: java.lang.Byte => b.byteValue
119+
case s: java.lang.Short => s.shortValue
120+
case i: java.lang.Integer => i.intValue
121+
case l: java.lang.Long => l.longValue
122+
case f: java.lang.Float => f.floatValue
123+
case d: java.lang.Double => d.doubleValue
124+
case x => x
125+
}
126+
// Convert array to Seq for easy equality check.
127+
case b: Array[_] => b.toSeq
128+
case r: Row => prepareRow(r)
129+
// SPARK-51349: "null" and null had the same precedence in sorting
130+
case "null" => "__null_string__"
131+
case o => o
132+
})
133+
}
134+
135+
private def genError(
136+
expectedAnswer: Seq[Row],
137+
sparkAnswer: Seq[Row],
138+
isSorted: Boolean = false): String = {
139+
val getRowType: Option[Row] => String = row =>
140+
row.map(row =>
141+
if (row.schema == null) {
142+
"struct<>"
143+
} else {
144+
s"${row.schema.catalogString}"
145+
}).getOrElse("struct<>")
146+
147+
s"""
148+
|== Results ==
149+
|${
150+
SparkStringUtils.sideBySide(
151+
s"== Correct Answer - ${expectedAnswer.size} ==" +:
152+
getRowType(expectedAnswer.headOption) +:
153+
prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
154+
s"== Spark Answer - ${sparkAnswer.size} ==" +:
155+
getRowType(sparkAnswer.headOption) +:
156+
prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")
157+
}
158+
""".stripMargin
159+
}
160+
161+
private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
162+
case (null, null) => true
163+
case (null, _) => false
164+
case (_, null) => false
165+
case (a: Array[_], b: Array[_]) =>
166+
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
167+
case (a: Map[_, _], b: Map[_, _]) =>
168+
a.size == b.size && a.keys.forall { aKey =>
169+
b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
170+
}
171+
case (a: Iterable[_], b: Iterable[_]) =>
172+
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
173+
case (a: Product, b: Product) =>
174+
compare(a.productIterator.toSeq, b.productIterator.toSeq)
175+
case (a: Row, b: Row) =>
176+
compare(a.toSeq, b.toSeq)
177+
// 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
178+
// in some hardware NaN can be represented with different bits, so first check for it
179+
case (a: Double, b: Double) =>
180+
a.isNaN && b.isNaN ||
181+
java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b)
182+
case (a: Float, b: Float) =>
183+
a.isNaN && b.isNaN ||
184+
java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b)
185+
case (a, b) => a == b
186+
}
187+
188+
private def sameRows( expectedAnswer: Seq[Row],
189+
sparkAnswer: Seq[Row],
190+
isSorted: Boolean = false): Option[String] = {
191+
if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) {
192+
return Some(genError(expectedAnswer, sparkAnswer, isSorted))
193+
}
194+
None
195+
}
196+
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.scalatest.Assertions
21+
22+
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.util.SparkErrorUtils
24+
25+
/**
26+
* Provides [[withTable]], [[withView]], and [[withUserDefinedFunction]]
27+
*/
28+
@Experimental
29+
trait QueryCleanupHelper extends SparkSessionProvider with Assertions {
30+
31+
/**
32+
* Drops table `tableName` after calling `f`.
33+
*/
34+
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
35+
SparkErrorUtils.tryWithSafeFinally(f) {
36+
tableNames.foreach { name =>
37+
spark.sql(s"DROP TABLE IF EXISTS $name")
38+
}
39+
}
40+
}
41+
42+
/**
43+
* Drops view `viewName` after calling `f`.
44+
*/
45+
protected def withView(viewNames: String*)(f: => Unit): Unit = {
46+
SparkErrorUtils.tryWithSafeFinally(f)(
47+
viewNames.foreach { name =>
48+
spark.sql(s"DROP VIEW IF EXISTS $name")
49+
}
50+
)
51+
}
52+
53+
protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = {
54+
try {
55+
f
56+
} catch {
57+
case cause: Throwable => throw cause
58+
} finally {
59+
functions.foreach { case (functionName, isTemporary) =>
60+
val withTemporary = if (isTemporary) "TEMPORARY" else ""
61+
spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName")
62+
assert(
63+
!spark.catalog.functionExists(functionName),
64+
s"Function $functionName should have been dropped. But, it still exists.")
65+
}
66+
}
67+
}
68+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
// scalastyle:off funsuite
21+
import org.scalatest.funsuite.AnyFunSuite
22+
// scalastyle:on
23+
24+
trait SessionQueryTestBase
25+
extends AnyFunSuite
26+
with SparkSessionProvider
27+
with CheckAnswerHelper
28+
with QueryCleanupHelper

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala renamed to sql/api/src/test/scala/org/apache/spark/sql/SparkSessionProvider.scala

File renamed without changes.

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/DataSourceV2DataFrameConnectSuite.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.{CachingInMemoryTableCatalog, InMe
3434
* this class only provides the Connect-specific session, catalog access, and result comparison.
3535
*/
3636
class DataSourceV2DataFrameConnectSuite
37-
extends SparkSessionBinder
37+
extends SessionQueryTest
3838
with DSv2TempViewWithStoredPlanTests
3939
with DSv2RepeatedTableAccessTests
4040
with DSv2IncrementallyConstructedQueryTests
@@ -53,7 +53,6 @@ class DataSourceV2DataFrameConnectSuite
5353
.set("spark.sql.catalog.nullbothidscat.copyOnLoad", "true")
5454

5555
override protected def testPrefix: String = "[connect] "
56-
override protected def isConnect: Boolean = true
5756

5857
override protected def getTableCatalog[C <: TableCatalog: ClassTag](
5958
session: SparkSession,

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ExampleConnectSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ package org.apache.spark.sql.connect
1919

2020
import org.apache.spark.sql
2121

22-
class ExampleConnectSuite extends sql.SparkSessionBinder
22+
class ExampleConnectSuite extends sql.ExampleSuite with SessionQueryTest

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala renamed to sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SessionQueryTest.scala

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,13 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
1817
package org.apache.spark.sql.connect
1918

20-
import org.apache.spark.{sql => sqlApi}
19+
import org.apache.spark.sql
2120

2221
/**
23-
* Extends [[sqlApi.QueryTest]] to provide connect-specific overrides to helpers like
24-
* [[checkAnswer]] that avoid classic-only APIs.
25-
*
26-
* Can be used together with [[SparkSessionBinder connect.SparkSessionBinder]] to create a
27-
* 'connect variant' of a test.
28-
*
29-
* Note: broader use will require more overrides.
22+
* TODO write docstring
3023
*/
31-
trait QueryTest extends sqlApi.QueryTest with SparkSessionProvider {
32-
33-
override protected def checkAnswer(
34-
df: => sqlApi.DataFrame, expectedAnswer: Seq[sqlApi.Row]): Unit = {
35-
val sparkAnswer = df.collect().toSeq
36-
sqlApi.QueryTest.sameRows(expectedAnswer, sparkAnswer) match {
37-
case Some(errorMessage) => fail(errorMessage)
38-
case None =>
39-
}
40-
}
24+
trait SessionQueryTest extends sql.SessionQueryTest with SparkSessionBinder {
25+
override def isDfSorted(df: sql.DataFrame): Boolean = false // TODO
4126
}

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect
1919

2020
import java.util.UUID
2121

22-
import org.apache.spark.SparkEnv
22+
import org.apache.spark.{SparkEnv, SparkFunSuite}
2323
import org.apache.spark.sql
2424
import org.apache.spark.sql.classic
2525
import org.apache.spark.sql.connect.client.SparkConnectClient
@@ -31,15 +31,8 @@ import org.apache.spark.sql.connect.service.SparkConnectService
3131
* Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a
3232
* [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client
3333
* session on top by starting the gRPC service in-process.
34-
*
35-
* Mix in this trait to exercise existing sql/core test suites through the Connect path:
36-
* {{{
37-
* class FooWithConnectSuite
38-
* extends FooSuite
39-
* with connect.SparkSessionBinder
40-
* }}}
4134
*/
42-
trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest {
35+
trait SparkSessionBinder extends sql.SparkSessionBinder { self: SparkFunSuite =>
4336

4437
private var _connectSpark: SparkSession = _
4538

0 commit comments

Comments
 (0)