Skip to content

Commit 0546da6

Browse files
committed
fix_ansi_support_when_non_using_collect
1 parent e64e194 commit 0546da6

2 files changed

Lines changed: 97 additions & 81 deletions

File tree

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,25 +1479,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791479
// cast() should return null for invalid inputs when ansi mode is disabled
14801480
val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
14811481
if (useDataFrameDiff) {
1482-
assertDataFrameEquals(df, assertCometNative = !hasIncompatibleType)
1483-
} else if (hasIncompatibleType) {
1484-
checkSparkAnswer(df)
1482+
assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
14851483
} else {
1486-
checkSparkAnswerAndOperator(df)
1484+
if (hasIncompatibleType) {
1485+
checkSparkAnswer(df)
1486+
} else {
1487+
checkSparkAnswerAndOperator(df)
1488+
}
14871489
}
14881490

14891491
if (testTry) {
14901492
data.createOrReplaceTempView("t")
1491-
// try_cast() should always return null for invalid inputs
1492-
// not using spark DSL since it `try_cast` is only available from Spark 4x
1493-
val df2 =
1494-
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1495-
if (useDataFrameDiff) {
1496-
assertDataFrameEquals(df2, assertCometNative = !hasIncompatibleType)
1497-
} else if (hasIncompatibleType) {
1493+
// try_cast() should always return null for invalid inputs
1494+
// not using spark DSL since it `try_cast` is only available from Spark 4x
1495+
val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1496+
if (hasIncompatibleType) {
14981497
checkSparkAnswer(df2)
14991498
} else {
1500-
checkSparkAnswerAndOperator(df2)
1499+
if (useDataFrameDiff) {
1500+
assertDataFrameEqualsWithExceptions(df2, assertCometNative = !hasIncompatibleType)
1501+
} else {
1502+
checkSparkAnswerAndOperator(df2)
1503+
}
15011504
}
15021505
}
15031506
}
@@ -1510,63 +1513,65 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15101513

15111514
// cast() should throw exception on invalid inputs when ansi mode is enabled
15121515
val df = data.withColumn("converted", col("a").cast(toType))
1513-
if (useDataFrameDiff) {
1514-
assertDataFrameEquals(df, assertCometNative = !hasIncompatibleType)
1516+
val res = if (useDataFrameDiff) {
1517+
assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
15151518
} else {
1516-
checkSparkAnswerMaybeThrows(df) match {
1517-
case (None, None) =>
1518-
// neither system threw an exception
1519-
case (None, Some(e)) =>
1520-
throw e
1521-
case (Some(e), None) =>
1522-
// Spark failed but Comet succeeded
1523-
fail(s"Comet should have failed with ${e.getCause.getMessage}")
1524-
case (Some(sparkException), Some(cometException)) =>
1525-
// both systems threw an exception so we make sure they are the same
1526-
val sparkMessage =
1527-
if (sparkException.getCause != null) sparkException.getCause.getMessage
1528-
else sparkException.getMessage
1529-
val cometMessage =
1530-
if (cometException.getCause != null) cometException.getCause.getMessage
1531-
else cometException.getMessage
1532-
// this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1533-
if (df.schema("a").dataType.typeName.contains("decimal") && toType.typeName
1534-
.contains("decimal") && sparkMessage.contains("cannot be represented as")) {
1535-
assert(cometMessage.contains("too large to store"))
1536-
} else {
1537-
if (CometSparkSessionExtensions.isSpark40Plus) {
1538-
// for Spark 4 we expect to sparkException carries the message
1539-
assert(sparkMessage.contains("SQLSTATE"))
1540-
if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
1541-
assert(
1542-
sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
1543-
} else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
1544-
.startsWith("[CAST_OVERFLOW]")) {
1545-
assert(
1546-
sparkMessage.startsWith(
1547-
cometMessage
1548-
.replace(
1549-
"If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
1550-
"")))
1551-
} else {
1552-
assert(sparkMessage.startsWith(cometMessage))
1553-
}
1519+
checkSparkAnswerMaybeThrows(df)
1520+
}
1521+
res match {
1522+
case (None, None) =>
1523+
// neither system threw an exception
1524+
case (None, Some(e)) =>
1525+
throw e
1526+
case (Some(e), None) =>
1527+
// Spark failed but Comet succeeded
1528+
fail(s"Comet should have failed with ${e.getCause.getMessage}")
1529+
case (Some(sparkException), Some(cometException)) =>
1530+
// both systems threw an exception so we make sure they are the same
1531+
val sparkMessage =
1532+
if (sparkException.getCause != null) sparkException.getCause.getMessage
1533+
else sparkException.getMessage
1534+
val cometMessage =
1535+
if (cometException.getCause != null) cometException.getCause.getMessage
1536+
else cometException.getMessage
1537+
// this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1538+
if (df.schema("a").dataType.typeName.contains("decimal") && toType.typeName
1539+
.contains("decimal") && sparkMessage.contains("cannot be represented as")) {
1540+
assert(cometMessage.contains("too large to store"))
1541+
} else {
1542+
if (CometSparkSessionExtensions.isSpark40Plus) {
1543+
// for Spark 4 we expect to sparkException carries the message
1544+
assert(sparkMessage.contains("SQLSTATE"))
1545+
if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
1546+
assert(
1547+
sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
1548+
} else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
1549+
.startsWith("[CAST_OVERFLOW]")) {
1550+
assert(
1551+
sparkMessage.startsWith(
1552+
cometMessage
1553+
.replace(
1554+
"If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
1555+
"")))
15541556
} else {
1555-
// for Spark 3.4 we expect to reproduce the error message exactly
1556-
assert(cometMessage == sparkMessage)
1557+
assert(sparkMessage.startsWith(cometMessage))
15571558
}
1559+
} else {
1560+
// for Spark 3.4 we expect to reproduce the error message exactly
1561+
assert(cometMessage == sparkMessage)
15581562
}
1559-
}
1563+
}
15601564
}
1565+
}
15611566

1562-
// try_cast() should always return null for invalid inputs
1563-
if (testTry) {
1564-
data.createOrReplaceTempView("t")
1565-
val df2 =
1566-
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1567-
if (useDataFrameDiff) {
1568-
assertDataFrameEquals(df2, assertCometNative = !hasIncompatibleType)
1569-
} else if (hasIncompatibleType) {
1567+
// try_cast() should always return null for invalid inputs
1568+
if (testTry) {
1569+
data.createOrReplaceTempView("t")
1570+
val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1571+
if (useDataFrameDiff) {
1572+
assertDataFrameEqualsWithExceptions(df2, assertCometNative = !hasIncompatibleType)
1573+
} else {
1574+
if (hasIncompatibleType) {
15701575
checkSparkAnswer(df2)
15711576
} else {
15721577
checkSparkAnswerAndOperator(df2)

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
2424
import scala.concurrent.duration._
2525
import scala.reflect.ClassTag
2626
import scala.reflect.runtime.universe.TypeTag
27-
import scala.util.{Success, Try}
27+
import scala.util.{Failure, Success, Try}
2828

2929
import org.scalatest.BeforeAndAfterEach
3030

@@ -43,7 +43,7 @@ import org.apache.spark.sql.execution._
4343
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
4444
import org.apache.spark.sql.internal._
4545
import org.apache.spark.sql.test._
46-
import org.apache.spark.sql.types.{DecimalType, StructType}
46+
import org.apache.spark.sql.types.{DataType, DecimalType, StructType}
4747

4848
import org.apache.comet._
4949
import org.apache.comet.shims.ShimCometSparkSessionExtensions
@@ -1278,12 +1278,12 @@ abstract class CometTestBase
12781278
}
12791279

12801280
/**
1281-
* Compares Spark and Comet results using exceptAll instead of collect(). This avoids
1282-
* java.sql.Timestamp overflow issues with extreme timestamp values.
1281+
* Compares Spark and Comet results using foreach() and exceptAll() to avoid collect(). Raises /
1282+
* validates right ANSI exception
12831283
*/
1284-
protected def assertDataFrameEquals(
1284+
protected def assertDataFrameEqualsWithExceptions(
12851285
df: => DataFrame,
1286-
assertCometNative: Boolean = true): Unit = {
1286+
assertCometNative: Boolean = true): (Option[Throwable], Option[Throwable]) = {
12871287

12881288
var dfSpark: DataFrame = null
12891289
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -1296,20 +1296,31 @@ abstract class CometTestBase
12961296
dfSpark.schema == dfComet.schema,
12971297
s"Schema mismatch:\nSpark: ${dfSpark.schema}\nComet: ${dfComet.schema}")
12981298

1299-
val sparkMinusComet = dfSpark.exceptAll(dfComet)
1300-
val cometMinusSpark = dfComet.exceptAll(dfSpark)
1301-
val diffCount1 = sparkMinusComet.count()
1302-
val diffCount2 = cometMinusSpark.count()
1299+
val expected = Try(dfSpark.foreach(_ => ()))
1300+
val actual = Try(dfComet.foreach(_ => ()))
13031301

1304-
if (diffCount1 > 0 || diffCount2 > 0) {
1305-
fail(
1306-
"Results do not match. " +
1307-
s"Rows in Spark but not Comet: $diffCount1. " +
1308-
s"Rows in Comet but not Spark: $diffCount2.")
1309-
}
1302+
(expected, actual) match {
1303+
case (Success(_), Success(_)) =>
1304+
// compare results and confirm that they match
1305+
val sparkMinusComet = dfSpark.exceptAll(dfComet)
1306+
val cometMinusSpark = dfComet.exceptAll(dfSpark)
1307+
val diffCount1 = sparkMinusComet.count()
1308+
val diffCount2 = cometMinusSpark.count()
13101309

1311-
if (assertCometNative) {
1312-
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
1310+
if (diffCount1 > 0 || diffCount2 > 0) {
1311+
fail(
1312+
"Results do not match. " +
1313+
s"Rows in Spark but not Comet: $diffCount1. " +
1314+
s"Rows in Comet but not Spark: $diffCount2.")
1315+
}
1316+
1317+
if (assertCometNative) {
1318+
checkCometOperators(stripAQEPlan(dfComet.queryExecution.executedPlan))
1319+
}
1320+
1321+
(None, None)
1322+
case _ =>
1323+
(expected.failed.toOption, actual.failed.toOption)
13131324
}
13141325
}
13151326
}

0 commit comments

Comments
 (0)