@@ -24,13 +24,12 @@ import scala.util.Random
2424import org .apache .hadoop .fs .Path
2525import org .apache .spark .sql .CometTestBase
2626import org .apache .spark .sql .catalyst .expressions .{ArrayAppend , ArrayDistinct , ArrayExcept , ArrayInsert , ArrayIntersect , ArrayJoin , ArrayRepeat , ArraysOverlap , ArrayUnion }
27+ import org .apache .spark .sql .catalyst .expressions .{ArrayContains , ArrayRemove , GetArrayItem }
2728import org .apache .spark .sql .execution .adaptive .AdaptiveSparkPlanHelper
2829import org .apache .spark .sql .functions ._
2930import org .apache .spark .sql .internal .SQLConf
3031import org .apache .spark .sql .types .ArrayType
3132
32- import org .apache .spark .sql .catalyst .expressions .{ArrayContains , ArrayRemove , GetArrayItem }
33-
3433import org .apache .comet .CometSparkSessionExtensions .{isSpark35Plus , isSpark40Plus }
3534import org .apache .comet .DataTypeSupport .isComplexType
3635import org .apache .comet .serde .{CometArrayExcept , CometArrayRemove , CometArrayReverse , CometFlatten }
@@ -40,52 +39,56 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
4039
4140 test(" array_remove - integer" ) {
4241 withSQLConf(CometConf .getExprAllowIncompatConfigKey(classOf [ArrayRemove ]) -> " true" ) {
43- Seq (true , false ).foreach { dictionaryEnabled =>
44- withTempView(" t1" ) {
45- withTempDir { dir =>
46- val path = new Path (dir.toURI.toString, " test.parquet" )
47- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000 )
48- spark.read.parquet(path.toString).createOrReplaceTempView(" t1" )
49- checkSparkAnswerAndOperator(
50- sql(" SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null" ))
51- checkSparkAnswerAndOperator(
52- sql(" SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null" ))
53- checkSparkAnswerAndOperator(sql(
54- " SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1" ))
42+ Seq (true , false ).foreach { dictionaryEnabled =>
43+ withTempView(" t1" ) {
44+ withTempDir { dir =>
45+ val path = new Path (dir.toURI.toString, " test.parquet" )
46+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000 )
47+ spark.read.parquet(path.toString).createOrReplaceTempView(" t1" )
48+ checkSparkAnswerAndOperator(
49+ sql(" SELECT array_remove(array(_2, _3,_4), _2) from t1 where _2 is null" ))
50+ checkSparkAnswerAndOperator(
51+ sql(" SELECT array_remove(array(_2, _3,_4), _3) from t1 where _3 is not null" ))
52+ checkSparkAnswerAndOperator(sql(
53+ " SELECT array_remove(case when _2 = _3 THEN array(_2, _3,_4) ELSE null END, _3) from t1" ))
54+ }
5555 }
5656 }
5757 }
58- }
5958 }
6059
6160 test(" array_remove - test all types (native Parquet reader)" ) {
6261 withSQLConf(CometConf .getExprAllowIncompatConfigKey(classOf [ArrayRemove ]) -> " true" ) {
63- withTempDir { dir =>
64- withTempView(" t1" ) {
65- val path = new Path (dir.toURI.toString, " test.parquet" )
66- val filename = path.toString
67- val random = new Random (42 )
68- withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
69- ParquetGenerator .makeParquetFile(
70- random,
71- spark,
72- filename,
73- 100 ,
74- SchemaGenOptions (generateArray = false , generateStruct = false , generateMap = false ),
75- DataGenOptions (allowNull = true , generateNegativeZero = true ))
76- }
77- val table = spark.read.parquet(filename)
78- table.createOrReplaceTempView(" t1" )
79- // test with array of each column
80- val fieldNames =
81- table.schema.fields
82- .filter(field => CometArrayRemove .isTypeSupported(field.dataType))
83- .map(_.name)
84- for (fieldName <- fieldNames) {
85- sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
86- .createOrReplaceTempView(" t2" )
87- val df = sql(" SELECT array_remove(a, b) FROM t2" )
88- checkSparkAnswerAndOperator(df)
62+ withTempDir { dir =>
63+ withTempView(" t1" ) {
64+ val path = new Path (dir.toURI.toString, " test.parquet" )
65+ val filename = path.toString
66+ val random = new Random (42 )
67+ withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
68+ ParquetGenerator .makeParquetFile(
69+ random,
70+ spark,
71+ filename,
72+ 100 ,
73+ SchemaGenOptions (
74+ generateArray = false ,
75+ generateStruct = false ,
76+ generateMap = false ),
77+ DataGenOptions (allowNull = true , generateNegativeZero = true ))
78+ }
79+ val table = spark.read.parquet(filename)
80+ table.createOrReplaceTempView(" t1" )
81+ // test with array of each column
82+ val fieldNames =
83+ table.schema.fields
84+ .filter(field => CometArrayRemove .isTypeSupported(field.dataType))
85+ .map(_.name)
86+ for (fieldName <- fieldNames) {
87+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
88+ .createOrReplaceTempView(" t2" )
89+ val df = sql(" SELECT array_remove(a, b) FROM t2" )
90+ checkSparkAnswerAndOperator(df)
91+ }
8992 }
9093 }
9194 }
@@ -123,7 +126,6 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
123126 }
124127 }
125128 }
126- }
127129 }
128130
129131 test(" array_remove - fallback for unsupported type struct" ) {
@@ -254,16 +256,17 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
254256 withSQLConf(CometConf .getExprAllowIncompatConfigKey(classOf [ArrayContains ]) -> " true" ) {
255257 withTempDir { dir =>
256258 withTempView(" t1" ) {
257- val path = new Path (dir.toURI.toString, " test.parquet" )
258- makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false , n = 10000 )
259- spark.read.parquet(path.toString).createOrReplaceTempView(" t1" );
260- checkSparkAnswerAndOperator(
261- spark.sql(" SELECT array_contains(array(_2, _3, _4), _2) FROM t1" ))
262- checkSparkAnswerAndOperator(
263- spark.sql(" SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1" ));
259+ val path = new Path (dir.toURI.toString, " test.parquet" )
260+ makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false , n = 10000 )
261+ spark.read.parquet(path.toString).createOrReplaceTempView(" t1" );
262+ checkSparkAnswerAndOperator(
263+ spark.sql(" SELECT array_contains(array(_2, _3, _4), _2) FROM t1" ))
264+ checkSparkAnswerAndOperator(
265+ spark.sql(
266+ " SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1" ));
267+ }
264268 }
265269 }
266- }
267270 }
268271
269272 test(" array_contains - test all types (native Parquet reader)" ) {
@@ -272,40 +275,40 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
272275 withTempView(" t1" , " t2" , " t3" ) {
273276 val path = new Path (dir.toURI.toString, " test.parquet" )
274277 val filename = path.toString
275- val random = new Random (42 )
276- withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
277- ParquetGenerator .makeParquetFile(
278- random,
279- spark,
280- filename,
281- 100 ,
282- SchemaGenOptions (generateArray = true , generateStruct = true , generateMap = false ),
283- DataGenOptions (allowNull = true , generateNegativeZero = true ))
284- }
285- val table = spark.read.parquet(filename)
286- table.createOrReplaceTempView(" t1" )
287- val complexTypeFields =
288- table.schema.fields.filter(field => isComplexType(field.dataType))
289- val primitiveTypeFields =
290- table.schema.fields.filterNot(field => isComplexType(field.dataType))
291- for (field <- primitiveTypeFields) {
292- val fieldName = field.name
293- val typeName = field.dataType.typeName
294- sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
295- .createOrReplaceTempView(" t2" )
296- checkSparkAnswerAndOperator(sql(" SELECT array_contains(a, b) FROM t2" ))
297- checkSparkAnswerAndOperator(
298- sql(s " SELECT array_contains(a, cast(null as $typeName)) FROM t2 " ))
299- }
300- for (field <- complexTypeFields) {
301- val fieldName = field.name
302- sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
303- .createOrReplaceTempView(" t3" )
304- checkSparkAnswer(sql(" SELECT array_contains(a, b) FROM t3" ))
278+ val random = new Random (42 )
279+ withSQLConf(CometConf .COMET_ENABLED .key -> " false" ) {
280+ ParquetGenerator .makeParquetFile(
281+ random,
282+ spark,
283+ filename,
284+ 100 ,
285+ SchemaGenOptions (generateArray = true , generateStruct = true , generateMap = false ),
286+ DataGenOptions (allowNull = true , generateNegativeZero = true ))
287+ }
288+ val table = spark.read.parquet(filename)
289+ table.createOrReplaceTempView(" t1" )
290+ val complexTypeFields =
291+ table.schema.fields.filter(field => isComplexType(field.dataType))
292+ val primitiveTypeFields =
293+ table.schema.fields.filterNot(field => isComplexType(field.dataType))
294+ for (field <- primitiveTypeFields) {
295+ val fieldName = field.name
296+ val typeName = field.dataType.typeName
297+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
298+ .createOrReplaceTempView(" t2" )
299+ checkSparkAnswerAndOperator(sql(" SELECT array_contains(a, b) FROM t2" ))
300+ checkSparkAnswerAndOperator(
301+ sql(s " SELECT array_contains(a, cast(null as $typeName)) FROM t2 " ))
302+ }
303+ for (field <- complexTypeFields) {
304+ val fieldName = field.name
305+ sql(s " SELECT array( $fieldName, $fieldName) as a, $fieldName as b FROM t1 " )
306+ .createOrReplaceTempView(" t3" )
307+ checkSparkAnswer(sql(" SELECT array_contains(a, b) FROM t3" ))
308+ }
305309 }
306310 }
307311 }
308- }
309312 }
310313
311314 test(" array_contains - array literals" ) {
0 commit comments