@@ -168,8 +168,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
168168 }
169169
170170 test(" cast BooleanType to TimestampType" ) {
171- // Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported
172- castTest(generateBools(), DataTypes .TimestampType )
171+ // Spark does not support ANSI or Try mode for Boolean to Timestamp casts
172+ castTest(generateBools(), DataTypes .TimestampType , testAnsi = false , testTry = false )
173173 }
174174
175175 // CAST from ByteType
@@ -437,14 +437,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
437437 }
438438
439439 test(" cast LongType to TimestampType" ) {
440- // Cast back to long avoids java.sql.Timestamp overflow during collect() for extreme values
441440 compatibleTimezones.foreach { tz =>
442441 withSQLConf(SQLConf .SESSION_LOCAL_TIMEZONE .key -> tz) {
443- withTable(" t1" ) {
444- generateLongs().write.saveAsTable(" t1" )
445- val df = spark.sql(" select a, cast(cast(a as timestamp) as long) from t1" )
446- checkSparkAnswerAndOperator(df)
447- }
442+ // Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
443+ castTest(generateLongs(), DataTypes .TimestampType , useDataFrameDiff = true )
448444 }
449445 }
450446 }
@@ -507,11 +503,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
507503 test(" cast FloatType to TimestampType" ) {
508504 compatibleTimezones.foreach { tz =>
509505 withSQLConf(SQLConf .SESSION_LOCAL_TIMEZONE .key -> tz) {
510- withTable(" t1" ) {
511- generateFloats().write.saveAsTable(" t1" )
512- val df = spark.sql(" select a, cast(a as timestamp) from t1" )
513- assertDataFrameEquals(df)
514- }
506+ // Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
507+ castTest(generateFloats(), DataTypes .TimestampType , useDataFrameDiff = true )
515508 }
516509 }
517510 }
@@ -570,11 +563,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
570563 test(" cast DoubleType to TimestampType" ) {
571564 compatibleTimezones.foreach { tz =>
572565 withSQLConf(SQLConf .SESSION_LOCAL_TIMEZONE .key -> tz) {
573- withTable(" t1" ) {
574- generateDoubles().write.saveAsTable(" t1" )
575- val df = spark.sql(" select a, cast(a as timestamp) from t1" )
576- assertDataFrameEquals(df)
577- }
566+ // Use useDataFrameDiff to avoid collect(), which fails on extreme timestamp values
567+ castTest(generateDoubles(), DataTypes .TimestampType , useDataFrameDiff = true )
578568 }
579569 }
580570 }
@@ -1479,15 +1469,18 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791469 toType : DataType ,
14801470 hasIncompatibleType : Boolean = false ,
14811471 testAnsi : Boolean = true ,
1482- testTry : Boolean = true ): Unit = {
1472+ testTry : Boolean = true ,
1473+ useDataFrameDiff : Boolean = false ): Unit = {
14831474
14841475 withTempPath { dir =>
14851476 val data = roundtripParquet(input, dir).coalesce(1 )
14861477
14871478 withSQLConf((SQLConf .ANSI_ENABLED .key, " false" )) {
14881479 // cast() should return null for invalid inputs when ansi mode is disabled
14891480 val df = data.select(col(" a" ), col(" a" ).cast(toType)).orderBy(col(" a" ))
1490- if (hasIncompatibleType) {
1481+ if (useDataFrameDiff) {
1482+ assertDataFrameEquals(df, assertCometNative = ! hasIncompatibleType)
1483+ } else if (hasIncompatibleType) {
14911484 checkSparkAnswer(df)
14921485 } else {
14931486 checkSparkAnswerAndOperator(df)
@@ -1499,7 +1492,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14991492// not using spark DSL since `try_cast` is only available from Spark 4.x
15001493 val df2 =
15011494 spark.sql(s " select a, try_cast(a as ${toType.sql}) from t order by a " )
1502- if (hasIncompatibleType) {
1495+ if (useDataFrameDiff) {
1496+ assertDataFrameEquals(df2, assertCometNative = ! hasIncompatibleType)
1497+ } else if (hasIncompatibleType) {
15031498 checkSparkAnswer(df2)
15041499 } else {
15051500 checkSparkAnswerAndOperator(df2)
@@ -1515,57 +1510,63 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15151510
15161511 // cast() should throw exception on invalid inputs when ansi mode is enabled
15171512 val df = data.withColumn(" converted" , col(" a" ).cast(toType))
1518- checkSparkAnswerMaybeThrows(df) match {
1519- case (None , None ) =>
1520- // neither system threw an exception
1521- case (None , Some (e)) =>
1522- throw e
1523- case (Some (e), None ) =>
1524- // Spark failed but Comet succeeded
1525- fail(s " Comet should have failed with ${e.getCause.getMessage}" )
1526- case (Some (sparkException), Some (cometException)) =>
1527- // both systems threw an exception so we make sure they are the same
1528- val sparkMessage =
1529- if (sparkException.getCause != null ) sparkException.getCause.getMessage
1530- else sparkException.getMessage
1531- val cometMessage =
1532- if (cometException.getCause != null ) cometException.getCause.getMessage
1533- else cometException.getMessage
1534- // this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1535- if (df.schema(" a" ).dataType.typeName.contains(" decimal" ) && toType.typeName
1536- .contains(" decimal" ) && sparkMessage.contains(" cannot be represented as" )) {
1537- assert(cometMessage.contains(" too large to store" ))
1538- } else {
1539- if (CometSparkSessionExtensions .isSpark40Plus) {
1540- // for Spark 4 we expect to sparkException carries the message
1541- assert(sparkMessage.contains(" SQLSTATE" ))
1542- if (sparkMessage.startsWith(" [NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]" )) {
1543- assert(
1544- sparkMessage.replace(" .WITH_SUGGESTION] " , " ]" ).startsWith(cometMessage))
1545- } else if (cometMessage.startsWith(" [CAST_INVALID_INPUT]" ) || cometMessage
1546- .startsWith(" [CAST_OVERFLOW]" )) {
1547- assert(
1548- sparkMessage.startsWith(
1549- cometMessage
1550- .replace(
1551- " If necessary set \" spark.sql.ansi.enabled\" to \" false\" to bypass this error." ,
1552- " " )))
1513+ if (useDataFrameDiff) {
1514+ assertDataFrameEquals(df, assertCometNative = ! hasIncompatibleType)
1515+ } else {
1516+ checkSparkAnswerMaybeThrows(df) match {
1517+ case (None , None ) =>
1518+ // neither system threw an exception
1519+ case (None , Some (e)) =>
1520+ throw e
1521+ case (Some (e), None ) =>
1522+ // Spark failed but Comet succeeded
1523+ fail(s " Comet should have failed with ${e.getCause.getMessage}" )
1524+ case (Some (sparkException), Some (cometException)) =>
1525+ // both systems threw an exception so we make sure they are the same
1526+ val sparkMessage =
1527+ if (sparkException.getCause != null ) sparkException.getCause.getMessage
1528+ else sparkException.getMessage
1529+ val cometMessage =
1530+ if (cometException.getCause != null ) cometException.getCause.getMessage
1531+ else cometException.getMessage
1532+ // this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1533+ if (df.schema(" a" ).dataType.typeName.contains(" decimal" ) && toType.typeName
1534+ .contains(" decimal" ) && sparkMessage.contains(" cannot be represented as" )) {
1535+ assert(cometMessage.contains(" too large to store" ))
1536+ } else {
1537+ if (CometSparkSessionExtensions .isSpark40Plus) {
1538+ // for Spark 4 we expect sparkException to carry the message
1539+ assert(sparkMessage.contains(" SQLSTATE" ))
1540+ if (sparkMessage.startsWith(" [NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]" )) {
1541+ assert(
1542+ sparkMessage.replace(" .WITH_SUGGESTION] " , " ]" ).startsWith(cometMessage))
1543+ } else if (cometMessage.startsWith(" [CAST_INVALID_INPUT]" ) || cometMessage
1544+ .startsWith(" [CAST_OVERFLOW]" )) {
1545+ assert(
1546+ sparkMessage.startsWith(
1547+ cometMessage
1548+ .replace(
1549+ " If necessary set \" spark.sql.ansi.enabled\" to \" false\" to bypass this error." ,
1550+ " " )))
1551+ } else {
1552+ assert(sparkMessage.startsWith(cometMessage))
1553+ }
15531554 } else {
1554- assert(sparkMessage.startsWith(cometMessage))
1555+ // for Spark 3.4 we expect to reproduce the error message exactly
1556+ assert(cometMessage == sparkMessage)
15551557 }
1556- } else {
1557- // for Spark 3.4 we expect to reproduce the error message exactly
1558- assert(cometMessage == sparkMessage)
15591558 }
1560- }
1559+ }
15611560 }
15621561
15631562 // try_cast() should always return null for invalid inputs
15641563 if (testTry) {
15651564 data.createOrReplaceTempView(" t" )
15661565 val df2 =
15671566 spark.sql(s " select a, try_cast(a as ${toType.sql}) from t order by a " )
1568- if (hasIncompatibleType) {
1567+ if (useDataFrameDiff) {
1568+ assertDataFrameEquals(df2, assertCometNative = ! hasIncompatibleType)
1569+ } else if (hasIncompatibleType) {
15691570 checkSparkAnswer(df2)
15701571 } else {
15711572 checkSparkAnswerAndOperator(df2)
0 commit comments