@@ -1479,25 +1479,28 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791479 // cast() should return null for invalid inputs when ansi mode is disabled
14801480 val df = data.select(col(" a" ), col(" a" ).cast(toType)).orderBy(col(" a" ))
14811481 if (useDataFrameDiff) {
1482- assertDataFrameEquals(df, assertCometNative = ! hasIncompatibleType)
1483- } else if (hasIncompatibleType) {
1484- checkSparkAnswer(df)
1482+ assertDataFrameEqualsWithExceptions(df, assertCometNative = ! hasIncompatibleType)
14851483 } else {
1486- checkSparkAnswerAndOperator(df)
1484+ if (hasIncompatibleType) {
1485+ checkSparkAnswer(df)
1486+ } else {
1487+ checkSparkAnswerAndOperator(df)
1488+ }
14871489 }
14881490
14891491 if (testTry) {
14901492 data.createOrReplaceTempView(" t" )
1491- // try_cast() should always return null for invalid inputs
1492- // not using spark DSL since it `try_cast` is only available from Spark 4x
1493- val df2 =
1494- spark.sql(s " select a, try_cast(a as ${toType.sql}) from t order by a " )
1495- if (useDataFrameDiff) {
1496- assertDataFrameEquals(df2, assertCometNative = ! hasIncompatibleType)
1497- } else if (hasIncompatibleType) {
1493+ // try_cast() should always return null for invalid inputs
1494+ // not using spark DSL since it `try_cast` is only available from Spark 4x
1495+ val df2 = spark.sql(s " select a, try_cast(a as ${toType.sql}) from t order by a " )
1496+ if (hasIncompatibleType) {
14981497 checkSparkAnswer(df2)
14991498 } else {
1500- checkSparkAnswerAndOperator(df2)
1499+ if (useDataFrameDiff) {
1500+ assertDataFrameEqualsWithExceptions(df2, assertCometNative = ! hasIncompatibleType)
1501+ } else {
1502+ checkSparkAnswerAndOperator(df2)
1503+ }
15011504 }
15021505 }
15031506 }
@@ -1510,63 +1513,65 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15101513
15111514 // cast() should throw exception on invalid inputs when ansi mode is enabled
15121515 val df = data.withColumn(" converted" , col(" a" ).cast(toType))
1513- if (useDataFrameDiff) {
1514- assertDataFrameEquals (df, assertCometNative = ! hasIncompatibleType)
1516+ val res = if (useDataFrameDiff) {
1517+ assertDataFrameEqualsWithExceptions (df, assertCometNative = ! hasIncompatibleType)
15151518 } else {
1516- checkSparkAnswerMaybeThrows(df) match {
1517- case (None , None ) =>
1518- // neither system threw an exception
1519- case (None , Some (e)) =>
1520- throw e
1521- case (Some (e), None ) =>
1522- // Spark failed but Comet succeeded
1523- fail(s " Comet should have failed with ${e.getCause.getMessage}" )
1524- case (Some (sparkException), Some (cometException)) =>
1525- // both systems threw an exception so we make sure they are the same
1526- val sparkMessage =
1527- if (sparkException.getCause != null ) sparkException.getCause.getMessage
1528- else sparkException.getMessage
1529- val cometMessage =
1530- if (cometException.getCause != null ) cometException.getCause.getMessage
1531- else cometException.getMessage
1532- // this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1533- if (df.schema(" a" ).dataType.typeName.contains(" decimal" ) && toType.typeName
1534- .contains(" decimal" ) && sparkMessage.contains(" cannot be represented as" )) {
1535- assert(cometMessage.contains(" too large to store" ))
1536- } else {
1537- if (CometSparkSessionExtensions .isSpark40Plus) {
1538- // for Spark 4 we expect to sparkException carries the message
1539- assert(sparkMessage.contains(" SQLSTATE" ))
1540- if (sparkMessage.startsWith(" [NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]" )) {
1541- assert(
1542- sparkMessage.replace(" .WITH_SUGGESTION] " , " ]" ).startsWith(cometMessage))
1543- } else if (cometMessage.startsWith(" [CAST_INVALID_INPUT]" ) || cometMessage
1544- .startsWith(" [CAST_OVERFLOW]" )) {
1545- assert(
1546- sparkMessage.startsWith(
1547- cometMessage
1548- .replace(
1549- " If necessary set \" spark.sql.ansi.enabled\" to \" false\" to bypass this error." ,
1550- " " )))
1551- } else {
1552- assert(sparkMessage.startsWith(cometMessage))
1553- }
1519+ checkSparkAnswerMaybeThrows(df)
1520+ }
1521+ res match {
1522+ case (None , None ) =>
1523+ // neither system threw an exception
1524+ case (None , Some (e)) =>
1525+ throw e
1526+ case (Some (e), None ) =>
1527+ // Spark failed but Comet succeeded
1528+ fail(s " Comet should have failed with ${e.getCause.getMessage}" )
1529+ case (Some (sparkException), Some (cometException)) =>
1530+ // both systems threw an exception so we make sure they are the same
1531+ val sparkMessage =
1532+ if (sparkException.getCause != null ) sparkException.getCause.getMessage
1533+ else sparkException.getMessage
1534+ val cometMessage =
1535+ if (cometException.getCause != null ) cometException.getCause.getMessage
1536+ else cometException.getMessage
1537+ // this if branch should only check decimal to decimal cast and errors when output precision, scale causes overflow.
1538+ if (df.schema(" a" ).dataType.typeName.contains(" decimal" ) && toType.typeName
1539+ .contains(" decimal" ) && sparkMessage.contains(" cannot be represented as" )) {
1540+ assert(cometMessage.contains(" too large to store" ))
1541+ } else {
1542+ if (CometSparkSessionExtensions .isSpark40Plus) {
1543+ // for Spark 4 we expect to sparkException carries the message
1544+ assert(sparkMessage.contains(" SQLSTATE" ))
1545+ if (sparkMessage.startsWith(" [NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION]" )) {
1546+ assert(
1547+ sparkMessage.replace(" .WITH_SUGGESTION] " , " ]" ).startsWith(cometMessage))
1548+ } else if (cometMessage.startsWith(" [CAST_INVALID_INPUT]" ) || cometMessage
1549+ .startsWith(" [CAST_OVERFLOW]" )) {
1550+ assert(
1551+ sparkMessage.startsWith(
1552+ cometMessage
1553+ .replace(
1554+ " If necessary set \" spark.sql.ansi.enabled\" to \" false\" to bypass this error." ,
1555+ " " )))
15541556 } else {
1555- // for Spark 3.4 we expect to reproduce the error message exactly
1556- assert(cometMessage == sparkMessage)
1557+ assert(sparkMessage.startsWith(cometMessage))
15571558 }
1559+ } else {
1560+ // for Spark 3.4 we expect to reproduce the error message exactly
1561+ assert(cometMessage == sparkMessage)
15581562 }
1559- }
1563+ }
15601564 }
1565+ }
15611566
1562- // try_cast() should always return null for invalid inputs
1563- if (testTry) {
1564- data.createOrReplaceTempView(" t" )
1565- val df2 =
1566- spark.sql( s " select a, try_cast(a as ${toType.sql} ) from t order by a " )
1567- if (useDataFrameDiff) {
1568- assertDataFrameEquals(df2, assertCometNative = ! hasIncompatibleType)
1569- } else if (hasIncompatibleType) {
1567+ // try_cast() should always return null for invalid inputs
1568+ if (testTry) {
1569+ data.createOrReplaceTempView(" t" )
1570+ val df2 = spark.sql( s " select a, try_cast(a as ${toType.sql} ) from t order by a " )
1571+ if (useDataFrameDiff) {
1572+ assertDataFrameEqualsWithExceptions(df2, assertCometNative = ! hasIncompatibleType)
1573+ } else {
1574+ if (hasIncompatibleType) {
15701575 checkSparkAnswer(df2)
15711576 } else {
15721577 checkSparkAnswerAndOperator(df2)
0 commit comments