diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 1f574f1231..24d5083621 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -530,11 +530,10 @@ fn cast_struct_to_struct( ColumnarValue::from(from_field), to.data_type(), cast_options, - ) - .unwrap(); - cast_result.to_array(array_length).unwrap() + )?; + cast_result.to_array(array_length) }) - .collect(); + .collect::<DataFusionResult<Vec<_>>>()?; Ok(Arc::new(StructArray::new( to_fields.clone(), @@ -961,6 +960,38 @@ mod tests { } } + #[test] + fn test_cast_nested_struct_to_struct_ansi_overflow_returns_error() { + let inner_values: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(128), None])); + let from_nested_fields = + Fields::from(vec![Field::new("long_value", DataType::Int64, true)]); + let nested: ArrayRef = Arc::new(StructArray::new( + from_nested_fields.clone(), + vec![inner_values], + None, + )); + let from_fields = Fields::from(vec![Field::new( + "nested", + DataType::Struct(from_nested_fields), + true, + )]); + let outer: ArrayRef = Arc::new(StructArray::new(from_fields, vec![nested], None)); + + let to_nested_fields = Fields::from(vec![Field::new("byte_value", DataType::Int8, true)]); + let to_fields = Fields::from(vec![Field::new( + "renamed_nested", + DataType::Struct(to_nested_fields), + true, + )]); + let result = spark_cast( + ColumnarValue::Array(outer), + &DataType::Struct(to_fields), + &SparkCastOptions::new(EvalMode::Ansi, "UTC", false), + ); + + assert!(result.is_err()); + } + #[test] fn test_cast_struct_to_struct_drop_column() { let a: ArrayRef = Arc::new(Int32Array::from(vec![ diff --git a/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_ansi.sql b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_ansi.sql new file mode 100644 index 0000000000..27498ad792 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_ansi.sql @@ 
-0,0 +1,102 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Config: spark.sql.ansi.enabled=true + +statement +CREATE TABLE test_cast_complex_ansi( + id int, + struct_value struct< + long_value:bigint, + string_value:string, + nested_value:struct<inner_long:bigint>>, + array_value array<bigint> +) USING parquet + +statement +INSERT INTO test_cast_complex_ansi VALUES + ( + 1, + named_struct( + 'long_value', cast(1 as bigint), + 'string_value', 'fits', + 'nested_value', named_struct('inner_long', cast(10 as bigint))), + array(cast(1 as bigint), cast(127 as bigint), cast(null as bigint)) + ), + ( + 2, + named_struct( + 'long_value', cast(128 as bigint), + 'string_value', 'too-large', + 'nested_value', named_struct('inner_long', cast(10 as bigint))), + array(cast(1 as bigint)) + ), + ( + 3, + named_struct( + 'long_value', cast(2 as bigint), + 'string_value', 'nested-too-small', + 'nested_value', named_struct('inner_long', cast(-129 as bigint))), + array(cast(2 as bigint)) + ), + ( + 4, + named_struct( + 'long_value', cast(3 as bigint), + 'string_value', 'array-too-large', + 'nested_value', named_struct('inner_long', cast(4 as bigint))), + array(cast(128 as bigint)) + ), + ( + 5, + cast(null as struct< + long_value:bigint, + 
string_value:string, + nested_value:struct<inner_long:bigint>>), + cast(null as array<bigint>) + ) + +-- valid complex casts should run natively under ANSI mode +query +SELECT + cast(struct_value as + struct<long_value:tinyint, string_value:string, nested_value:struct<inner_long:tinyint>>), + cast(array_value as array<tinyint>), + id +FROM test_cast_complex_ansi +WHERE id IN (1, 5) +ORDER BY id + +-- overflow in a struct field should propagate as a cast error +query expect_error(CAST_OVERFLOW) +SELECT cast(struct_value as + struct<long_value:tinyint, string_value:string, nested_value:struct<inner_long:tinyint>>) +FROM test_cast_complex_ansi +WHERE id = 2 + +-- overflow in a nested struct field should propagate as a cast error +query expect_error(CAST_OVERFLOW) +SELECT cast(struct_value as + struct<long_value:tinyint, string_value:string, nested_value:struct<inner_long:tinyint>>) +FROM test_cast_complex_ansi +WHERE id = 3 + +-- overflow in an array element should propagate as a cast error +query expect_error(CAST_OVERFLOW) +SELECT cast(array_value as array<tinyint>) +FROM test_cast_complex_ansi +WHERE id = 4 diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index aac1bc0081..482183e914 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -22,6 +22,7 @@ package org.apache.comet import java.io.File import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ import scala.util.Random import org.apache.hadoop.fs.Path @@ -1465,6 +1466,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + val nestedType = + StructType(Seq(StructField("long_value", LongType), StructField("bool_value", BooleanType))) + val structType = StructType( + Seq( + StructField("int_value", IntegerType), + StructField("string_value", StringType), + StructField("nested_value", nestedType))) + val schema = StructType(Seq(StructField("a", structType))) + val rows = Seq( + Row(Row(1, "one", Row(10L, true))), + Row(Row(null, "missing-int", Row(-2L, false))), + Row(Row(3, null, null)), + Row(null)) + + castTest(spark.createDataFrame(rows.asJava, schema), 
StringType) } test("cast StructType to StructType") { @@ -1479,6 +1496,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + val fromNestedType = StructType(Seq(StructField("inner_int", IntegerType))) + val fromType = StructType( + Seq( + StructField("long_value", LongType), + StructField("string_value", StringType), + StructField("nested_value", fromNestedType))) + val toNestedType = StructType(Seq(StructField("renamed_inner_long", LongType))) + val toType = StructType( + Seq( + StructField("renamed_byte", ByteType), + StructField("renamed_string", StringType), + StructField("renamed_nested", toNestedType))) + val schema = StructType(Seq(StructField("a", fromType))) + val rows = Seq( + Row(Row(1L, "one", Row(10))), + Row(Row(127L, null, Row(-20))), + Row(Row(null, "missing-long", null)), + Row(null)) + + castTest(spark.createDataFrame(rows.asJava, schema), toType) + + val overflowFromType = StructType( + Seq(StructField("long_value", LongType), StructField("string_value", StringType))) + val overflowToType = StructType( + Seq(StructField("renamed_byte", ByteType), StructField("renamed_string", StringType))) + val overflowSchema = StructType(Seq(StructField("a", overflowFromType))) + val overflowRows = Seq( + Row(Row(1L, "fits")), + Row(Row(128L, "too-large")), + Row(Row(-129L, "too-small")), + Row(Row(null, "missing-long")), + Row(null)) + + castTest( + spark.createDataFrame(overflowRows.asJava, overflowSchema), + overflowToType, + expectAnsiFailure = true) } test("cast StructType to StructType with different names") { @@ -1564,8 +1619,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("cast ArrayType to StringType - float double binary edge cases") { - import scala.jdk.CollectionConverters._ - def bytes(values: Int*): Array[Byte] = values.map(_.toByte).toArray def arrayInput(elementType: DataType, values: Seq[Any]): DataFrame = { @@ -1630,6 +1683,19 @@ class CometCastSuite extends 
CometTestBase with AdaptiveSparkPlanHelper { DataTypes.TimestampNTZType, BinaryType) testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _)) + + val schema = StructType(Seq(StructField("a", ArrayType(LongType)))) + val rows = Seq( + Row(Seq[Any](1L, 127L, null)), + Row(Seq[Any](128L)), + Row(Seq[Any](-129L, 0L)), + Row(Seq.empty[Any]), + Row(null)) + + castTest( + spark.createDataFrame(rows.asJava, schema), + ArrayType(ByteType), + expectAnsiFailure = true) } test("cast MapType to MapType") { @@ -1639,7 +1705,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { // the planner routes Map→Map casts into it. The map column must be read // natively for the cast to be exercised by Comet, which only happens // under the V1 Parquet scan, so we pin USE_V1_SOURCE_LIST=parquet. - import scala.collection.JavaConverters._ val schema = StructType(Seq(StructField("a", MapType(IntegerType, IntegerType), nullable = true))) val rows = Range(0, 100).map { i => @@ -1837,7 +1902,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = { - import scala.jdk.CollectionConverters._ val schema = StructType(Seq(StructField("a", ArrayType(elementType), true))) def buildRows(values: Seq[Any]): Seq[Row] = { Range(0, rowNum).map { i => @@ -1899,7 +1963,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } private def generateNestedArrays(rowNum: Int, elementType: DataType): DataFrame = { - import scala.jdk.CollectionConverters._ val schema = StructType(Seq(StructField("a", ArrayType(ArrayType(elementType)), true))) val innerArrays = generateArrays(rowNum, elementType) .collect() @@ -2214,6 +2277,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { hasIncompatibleType: Boolean = false, testAnsi: Boolean = true, testTry: Boolean = true, + expectAnsiFailure: Boolean = false, useDataFrameDiff: Boolean = false): 
Unit = { withTempPath { dir => @@ -2261,11 +2325,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { .select(col("__row_id"), col("a"), col("a").cast(toType).as("converted")) .orderBy(col("__row_id")) .drop("__row_id") + if (expectAnsiFailure) { + assert(!hasIncompatibleType, "Expected ANSI failures must use Comet native execution") + checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan)) + } val res = if (useDataFrameDiff) { assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType) } else { checkSparkAnswerMaybeThrows(df) } + if (expectAnsiFailure) { + assert(res._1.isDefined, "Expected Spark ANSI cast to fail") + assert(res._2.isDefined, "Expected Comet ANSI cast to fail") + } res match { case (None, None) => // neither system threw an exception