Skip to content

Commit e64e194

Browse files
committed
fix_bool_to_timestamp_support
1 parent 55cae7f commit e64e194

2 files changed

Lines changed: 76 additions & 74 deletions

File tree

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 63 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
168168
}
169169

170170
test("cast BooleanType to TimestampType") {
171-
// Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported
172-
castTest(generateBools(), DataTypes.TimestampType)
171+
// Spark does not support ANSI or Try mode for Boolean to Timestamp casts
172+
castTest(generateBools(), DataTypes.TimestampType, testAnsi = false, testTry = false)
173173
}
174174

175175
// CAST from ByteType
@@ -437,14 +437,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
437437
}
438438

439439
test("cast LongType to TimestampType") {
440-
// Cast back to long avoids java.sql.Timestamp overflow during collect() for extreme values
441440
compatibleTimezones.foreach { tz =>
442441
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
443-
withTable("t1") {
444-
generateLongs().write.saveAsTable("t1")
445-
val df = spark.sql("select a, cast(cast(a as timestamp) as long) from t1")
446-
checkSparkAnswerAndOperator(df)
447-
}
442+
// Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
443+
castTest(generateLongs(), DataTypes.TimestampType, useDataFrameDiff = true)
448444
}
449445
}
450446
}
@@ -507,11 +503,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
507503
test("cast FloatType to TimestampType") {
508504
compatibleTimezones.foreach { tz =>
509505
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
510-
withTable("t1") {
511-
generateFloats().write.saveAsTable("t1")
512-
val df = spark.sql("select a, cast(a as timestamp) from t1")
513-
assertDataFrameEquals(df)
514-
}
506+
// Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
507+
castTest(generateFloats(), DataTypes.TimestampType, useDataFrameDiff = true)
515508
}
516509
}
517510
}
@@ -570,11 +563,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
570563
test("cast DoubleType to TimestampType") {
571564
compatibleTimezones.foreach { tz =>
572565
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
573-
withTable("t1") {
574-
generateDoubles().write.saveAsTable("t1")
575-
val df = spark.sql("select a, cast(a as timestamp) from t1")
576-
assertDataFrameEquals(df)
577-
}
566+
// Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
567+
castTest(generateDoubles(), DataTypes.TimestampType, useDataFrameDiff = true)
578568
}
579569
}
580570
}
@@ -1479,15 +1469,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791469
toType: DataType,
14801470
hasIncompatibleType: Boolean = false,
14811471
testAnsi: Boolean = true,
1482-
testTry: Boolean = true): Unit = {
1472+
testTry: Boolean = true,
1473+
useDataFrameDiff: Boolean = false): Unit = {
14831474

14841475
withTempPath { dir =>
14851476
val data = roundtripParquet(input, dir).coalesce(1)
14861477

14871478
withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
14881479
// cast() should return null for invalid inputs when ansi mode is disabled
14891480
val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
1490-
if (hasIncompatibleType) {
1481+
if (useDataFrameDiff) {
1482+
assertDataFrameEquals(df, assertCometNative = !hasIncompatibleType)
1483+
} else if (hasIncompatibleType) {
14911484
checkSparkAnswer(df)
14921485
} else {
14931486
checkSparkAnswerAndOperator(df)
@@ -1499,7 +1492,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14991492
// not using the Spark DSL here since its `try_cast` is only available from Spark 4.x
15001493
val df2 =
15011494
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1502-
if (hasIncompatibleType) {
1495+
if (useDataFrameDiff) {
1496+
assertDataFrameEquals(df2, assertCometNative = !hasIncompatibleType)
1497+
} else if (hasIncompatibleType) {
15031498
checkSparkAnswer(df2)
15041499
} else {
15051500
checkSparkAnswerAndOperator(df2)
@@ -1515,57 +1510,63 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15151510

15161511
// cast() should throw exception on invalid inputs when ansi mode is enabled
15171512
val df = data.withColumn("converted", col("a").cast(toType))
1518-
checkSparkAnswerMaybeThrows(df) match {
1519-
case (None, None) =>
1520-
// neither system threw an exception
1521-
case (None, Some(e)) =>
1522-
throw e
1523-
case (Some(e), None) =>
1524-
// Spark failed but Comet succeeded
1525-
fail(s"Comet should have failed with ${e.getCause.getMessage}")
1526-
case (Some(sparkException), Some(cometException)) =>
1527-
// both systems threw an exception so we make sure they are the same
1528-
val sparkMessage =
1529-
if (sparkException.getCause != null) sparkException.getCause.getMessage
1530-
else sparkException.getMessage
1531-
val cometMessage =
1532-
if (cometException.getCause != null) cometException.getCause.getMessage
1533-
else cometException.getMessage
1534-
// this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1535-
if (df.schema("a").dataType.typeName.contains("decimal") && toType.typeName
1536-
.contains("decimal") && sparkMessage.contains("cannot be represented as")) {
1537-
assert(cometMessage.contains("too large to store"))
1538-
} else {
1539-
if (CometSparkSessionExtensions.isSpark40Plus) {
1540-
// for Spark 4 we expect to sparkException carries the message
1541-
assert(sparkMessage.contains("SQLSTATE"))
1542-
if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
1543-
assert(
1544-
sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
1545-
} else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
1546-
.startsWith("[CAST_OVERFLOW]")) {
1547-
assert(
1548-
sparkMessage.startsWith(
1549-
cometMessage
1550-
.replace(
1551-
"If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
1552-
"")))
1513+
if (useDataFrameDiff) {
1514+
assertDataFrameEquals(df, assertCometNative = !hasIncompatibleType)
1515+
} else {
1516+
checkSparkAnswerMaybeThrows(df) match {
1517+
case (None, None) =>
1518+
// neither system threw an exception
1519+
case (None, Some(e)) =>
1520+
throw e
1521+
case (Some(e), None) =>
1522+
// Spark failed but Comet succeeded
1523+
fail(s"Comet should have failed with ${e.getCause.getMessage}")
1524+
case (Some(sparkException), Some(cometException)) =>
1525+
// both systems threw an exception so we make sure they are the same
1526+
val sparkMessage =
1527+
if (sparkException.getCause != null) sparkException.getCause.getMessage
1528+
else sparkException.getMessage
1529+
val cometMessage =
1530+
if (cometException.getCause != null) cometException.getCause.getMessage
1531+
else cometException.getMessage
1532+
// this branch should only apply to decimal-to-decimal casts that fail because the output precision/scale causes an overflow
1533+
if (df.schema("a").dataType.typeName.contains("decimal") && toType.typeName
1534+
.contains("decimal") && sparkMessage.contains("cannot be represented as")) {
1535+
assert(cometMessage.contains("too large to store"))
1536+
} else {
1537+
if (CometSparkSessionExtensions.isSpark40Plus) {
1538+
// for Spark 4 we expect the sparkException to carry the message
1539+
assert(sparkMessage.contains("SQLSTATE"))
1540+
if (sparkMessage.startsWith("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]")) {
1541+
assert(
1542+
sparkMessage.replace(".WITH_SUGGESTION] ", "]").startsWith(cometMessage))
1543+
} else if (cometMessage.startsWith("[CAST_INVALID_INPUT]") || cometMessage
1544+
.startsWith("[CAST_OVERFLOW]")) {
1545+
assert(
1546+
sparkMessage.startsWith(
1547+
cometMessage
1548+
.replace(
1549+
"If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.",
1550+
"")))
1551+
} else {
1552+
assert(sparkMessage.startsWith(cometMessage))
1553+
}
15531554
} else {
1554-
assert(sparkMessage.startsWith(cometMessage))
1555+
// for Spark 3.4 we expect to reproduce the error message exactly
1556+
assert(cometMessage == sparkMessage)
15551557
}
1556-
} else {
1557-
// for Spark 3.4 we expect to reproduce the error message exactly
1558-
assert(cometMessage == sparkMessage)
15591558
}
1560-
}
1559+
}
15611560
}
15621561

15631562
// try_cast() should always return null for invalid inputs
15641563
if (testTry) {
15651564
data.createOrReplaceTempView("t")
15661565
val df2 =
15671566
spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
1568-
if (hasIncompatibleType) {
1567+
if (useDataFrameDiff) {
1568+
assertDataFrameEquals(df2, assertCometNative = !hasIncompatibleType)
1569+
} else if (hasIncompatibleType) {
15691570
checkSparkAnswer(df2)
15701571
} else {
15711572
checkSparkAnswerAndOperator(df2)

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,33 +1278,34 @@ abstract class CometTestBase
12781278
}
12791279

12801280
/**
1281-
* Uses except (difference) to find differences without using collect() Checks cometDF and
1282-
* sparkDF including schemas
1281+
* Compares Spark and Comet results using exceptAll instead of collect(). This avoids
1282+
* java.sql.Timestamp overflow issues with extreme timestamp values.
12831283
*/
12841284
protected def assertDataFrameEquals(
12851285
df: => DataFrame,
12861286
assertCometNative: Boolean = true): Unit = {
12871287

1288-
var sparkDf: DataFrame = null
1288+
var dfSpark: DataFrame = null
12891289
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
1290-
sparkDf = datasetOfRows(spark, df.logicalPlan)
1290+
dfSpark = datasetOfRows(spark, df.logicalPlan)
12911291
}
1292-
val cometDf = datasetOfRows(spark, df.logicalPlan)
1292+
val dfComet = datasetOfRows(spark, df.logicalPlan)
12931293

12941294
// Compare schemas
12951295
assert(
1296-
sparkDf.schema == cometDf.schema,
1297-
s"Schema mismatch:\nCorrect Answer: ${sparkDf.schema}\nSpark Answer: ${cometDf.schema}")
1298-
1299-
// Use except (difference) to compare DataFrames without collect() which error on extremely high Timestamp values
1300-
val sparkMinusComet = sparkDf.except(cometDf)
1301-
val cometMinusSpark = cometDf.except(sparkDf)
1296+
dfSpark.schema == dfComet.schema,
1297+
s"Schema mismatch:\nSpark: ${dfSpark.schema}\nComet: ${dfComet.schema}")
13021298

1299+
val sparkMinusComet = dfSpark.exceptAll(dfComet)
1300+
val cometMinusSpark = dfComet.exceptAll(dfSpark)
13031301
val diffCount1 = sparkMinusComet.count()
13041302
val diffCount2 = cometMinusSpark.count()
13051303

13061304
if (diffCount1 > 0 || diffCount2 > 0) {
1307-
fail("DataFrames count doesnt match.\n")
1305+
fail(
1306+
"Results do not match. " +
1307+
s"Rows in Spark but not Comet: $diffCount1. " +
1308+
s"Rows in Comet but not Spark: $diffCount2.")
13081309
}
13091310

13101311
if (assertCometNative) {

0 commit comments

Comments
 (0)