@@ -22,6 +22,7 @@ package org.apache.comet
2222import java .io .File
2323
2424import scala .collection .mutable .ListBuffer
25+ import scala .jdk .CollectionConverters ._
2526import scala .util .Random
2627
2728import org .apache .hadoop .fs .Path
@@ -1465,6 +1466,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14651466 }
14661467 }
14671468 }
1469+
1470+ val nestedType =
1471+ StructType (Seq (StructField (" long_value" , LongType ), StructField (" bool_value" , BooleanType )))
1472+ val structType = StructType (
1473+ Seq (
1474+ StructField (" int_value" , IntegerType ),
1475+ StructField (" string_value" , StringType ),
1476+ StructField (" nested_value" , nestedType)))
1477+ val schema = StructType (Seq (StructField (" a" , structType)))
1478+ val rows = Seq (
1479+ Row (Row (1 , " one" , Row (10L , true ))),
1480+ Row (Row (null , " missing-int" , Row (- 2L , false ))),
1481+ Row (Row (3 , null , null )),
1482+ Row (null ))
1483+
1484+ castTest(spark.createDataFrame(rows.asJava, schema), StringType )
14681485 }
14691486
14701487 test(" cast StructType to StructType" ) {
@@ -1479,6 +1496,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791496 }
14801497 }
14811498 }
1499+
1500+ val fromNestedType = StructType (Seq (StructField (" inner_int" , IntegerType )))
1501+ val fromType = StructType (
1502+ Seq (
1503+ StructField (" long_value" , LongType ),
1504+ StructField (" string_value" , StringType ),
1505+ StructField (" nested_value" , fromNestedType)))
1506+ val toNestedType = StructType (Seq (StructField (" renamed_inner_long" , LongType )))
1507+ val toType = StructType (
1508+ Seq (
1509+ StructField (" renamed_byte" , ByteType ),
1510+ StructField (" renamed_string" , StringType ),
1511+ StructField (" renamed_nested" , toNestedType)))
1512+ val schema = StructType (Seq (StructField (" a" , fromType)))
1513+ val rows = Seq (
1514+ Row (Row (1L , " one" , Row (10 ))),
1515+ Row (Row (127L , null , Row (- 20 ))),
1516+ Row (Row (null , " missing-long" , null )),
1517+ Row (null ))
1518+
1519+ castTest(spark.createDataFrame(rows.asJava, schema), toType)
1520+
1521+ val overflowFromType = StructType (
1522+ Seq (StructField (" long_value" , LongType ), StructField (" string_value" , StringType )))
1523+ val overflowToType = StructType (
1524+ Seq (StructField (" renamed_byte" , ByteType ), StructField (" renamed_string" , StringType )))
1525+ val overflowSchema = StructType (Seq (StructField (" a" , overflowFromType)))
1526+ val overflowRows = Seq (
1527+ Row (Row (1L , " fits" )),
1528+ Row (Row (128L , " too-large" )),
1529+ Row (Row (- 129L , " too-small" )),
1530+ Row (Row (null , " missing-long" )),
1531+ Row (null ))
1532+
1533+ castTest(
1534+ spark.createDataFrame(overflowRows.asJava, overflowSchema),
1535+ overflowToType,
1536+ expectAnsiFailure = true )
14821537 }
14831538
14841539 test(" cast StructType to StructType with different names" ) {
@@ -1564,8 +1619,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15641619 }
15651620
15661621 test(" cast ArrayType to StringType - float double binary edge cases" ) {
1567- import scala .jdk .CollectionConverters ._
1568-
15691622 def bytes (values : Int * ): Array [Byte ] = values.map(_.toByte).toArray
15701623
15711624 def arrayInput (elementType : DataType , values : Seq [Any ]): DataFrame = {
@@ -1630,6 +1683,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
16301683 DataTypes .TimestampNTZType ,
16311684 BinaryType )
16321685 testArrayCastMatrix(types, ArrayType (_), generateArrays(100 , _))
1686+
1687+ val schema = StructType (Seq (StructField (" a" , ArrayType (LongType ))))
1688+ val rows = Seq (
1689+ Row (Seq [Any ](1L , 127L , null )),
1690+ Row (Seq [Any ](128L )),
1691+ Row (Seq [Any ](- 129L , 0L )),
1692+ Row (Seq .empty[Any ]),
1693+ Row (null ))
1694+
1695+ castTest(
1696+ spark.createDataFrame(rows.asJava, schema),
1697+ ArrayType (ByteType ),
1698+ expectAnsiFailure = true )
16331699 }
16341700
16351701 test(" cast MapType to MapType" ) {
@@ -1639,7 +1705,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
16391705 // the planner routes Map→Map casts into it. The map column must be read
16401706 // natively for the cast to be exercised by Comet, which only happens
16411707 // under the V1 Parquet scan, so we pin USE_V1_SOURCE_LIST=parquet.
1642- import scala .collection .JavaConverters ._
16431708 val schema =
16441709 StructType (Seq (StructField (" a" , MapType (IntegerType , IntegerType ), nullable = true )))
16451710 val rows = Range (0 , 100 ).map { i =>
@@ -1837,7 +1902,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18371902 }
18381903
18391904 private def generateArrays (rowNum : Int , elementType : DataType ): DataFrame = {
1840- import scala .jdk .CollectionConverters ._
18411905 val schema = StructType (Seq (StructField (" a" , ArrayType (elementType), true )))
18421906 def buildRows (values : Seq [Any ]): Seq [Row ] = {
18431907 Range (0 , rowNum).map { i =>
@@ -1899,7 +1963,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18991963 }
19001964
19011965 private def generateNestedArrays (rowNum : Int , elementType : DataType ): DataFrame = {
1902- import scala .jdk .CollectionConverters ._
19031966 val schema = StructType (Seq (StructField (" a" , ArrayType (ArrayType (elementType)), true )))
19041967 val innerArrays = generateArrays(rowNum, elementType)
19051968 .collect()
@@ -2214,6 +2277,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22142277 hasIncompatibleType : Boolean = false ,
22152278 testAnsi : Boolean = true ,
22162279 testTry : Boolean = true ,
2280+ expectAnsiFailure : Boolean = false ,
22172281 useDataFrameDiff : Boolean = false ): Unit = {
22182282
22192283 withTempPath { dir =>
@@ -2261,11 +2325,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22612325 .select(col(" __row_id" ), col(" a" ), col(" a" ).cast(toType).as(" converted" ))
22622326 .orderBy(col(" __row_id" ))
22632327 .drop(" __row_id" )
2328+ if (expectAnsiFailure) {
2329+ assert(! hasIncompatibleType, " Expected ANSI failures must use Comet native execution" )
2330+ checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
2331+ }
22642332 val res = if (useDataFrameDiff) {
22652333 assertDataFrameEqualsWithExceptions(df, assertCometNative = ! hasIncompatibleType)
22662334 } else {
22672335 checkSparkAnswerMaybeThrows(df)
22682336 }
2337+ if (expectAnsiFailure) {
2338+ assert(res._1.isDefined, " Expected Spark ANSI cast to fail" )
2339+ assert(res._2.isDefined, " Expected Comet ANSI cast to fail" )
2340+ }
22692341 res match {
22702342 case (None , None ) =>
22712343 // neither system threw an exception
0 commit comments