@@ -21,6 +21,7 @@ package org.apache.comet
2121
2222import java .io .File
2323
24+ import scala .collection .mutable .ListBuffer
2425import scala .util .Random
2526import scala .util .matching .Regex
2627
@@ -30,10 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.Cast
3031import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
3132import org .apache .spark .sql .functions .col
3233import org .apache .spark .sql .internal .SQLConf
33- import org .apache .spark .sql .types .{ArrayType , BinaryType , BooleanType , ByteType , DataType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType , StructField , StructType }
34+ import org .apache .spark .sql .types .{ArrayType , BinaryType , BooleanType , ByteType , DataType , DataTypes , DecimalType , DoubleType , FloatType , IntegerType , LongType , ShortType , StringType , StructField , StructType }
3435
3536import org .apache .comet .CometSparkSessionExtensions .isSpark40Plus
3637import org .apache .comet .expressions .{CometCast , CometEvalMode }
38+ import org .apache .comet .rules .CometScanTypeChecker
3739import org .apache .comet .serde .Compatible
3840
3941class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -1035,7 +1037,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10351037 }
10361038
10371039 test(" cast ArrayType to StringType" ) {
1038- sql(" set spark.comet.explainFallback.enabled=true" )
1040+ val cometScanTypeChecker = CometScanTypeChecker (conf.getConfString(CometConf .COMET_NATIVE_SCAN_IMPL .key))
1041+ val scanImpl = conf.getConfString(CometConf .COMET_NATIVE_SCAN_IMPL .key)
1042+ val hasIncompatibleType = (dt : DataType ) => cometScanTypeChecker.isTypeSupported(dt, scanImpl, ListBuffer .empty)
10391043 Seq (
10401044 BooleanType ,
10411045 StringType ,
@@ -1046,9 +1050,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10461050 // FloatType,
10471051 // DoubleType,
10481052 DecimalType (10 , 2 ),
1049- BinaryType ).foreach { tpe =>
1050- val input = generateArrays(100 , tpe)
1051- castTest(input, StringType )
1053+ DecimalType (38 ,18 ),
1054+ BinaryType ).foreach { dt =>
1055+ val input = generateArrays(100 , dt)
1056+ castTest(input, StringType , hasIncompatibleType(input.schema))
10521057 }
10531058 }
10541059
0 commit comments