Skip to content

Commit e2f89dd

Browse files
manuzhangcodex
andcommitted
fix: propagate nested cast errors
Return nested struct cast failures through DataFusion errors instead of unwrapping arrays during struct-to-struct casts. Add coverage for nested struct, array, and ANSI overflow cases. Co-authored-by: Codex <codex@openai.com>
1 parent dfc1588 commit e2f89dd

2 files changed

Lines changed: 112 additions & 8 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -530,11 +530,10 @@ fn cast_struct_to_struct(
530530
ColumnarValue::from(from_field),
531531
to.data_type(),
532532
cast_options,
533-
)
534-
.unwrap();
535-
cast_result.to_array(array_length).unwrap()
533+
)?;
534+
cast_result.to_array(array_length)
536535
})
537-
.collect();
536+
.collect::<DataFusionResult<Vec<_>>>()?;
538537

539538
Ok(Arc::new(StructArray::new(
540539
to_fields.clone(),
@@ -961,6 +960,38 @@ mod tests {
961960
}
962961
}
963962

963+
#[test]
964+
fn test_cast_nested_struct_to_struct_ansi_overflow_returns_error() {
965+
let inner_values: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(128), None]));
966+
let from_nested_fields =
967+
Fields::from(vec![Field::new("long_value", DataType::Int64, true)]);
968+
let nested: ArrayRef = Arc::new(StructArray::new(
969+
from_nested_fields.clone(),
970+
vec![inner_values],
971+
None,
972+
));
973+
let from_fields = Fields::from(vec![Field::new(
974+
"nested",
975+
DataType::Struct(from_nested_fields),
976+
true,
977+
)]);
978+
let outer: ArrayRef = Arc::new(StructArray::new(from_fields, vec![nested], None));
979+
980+
let to_nested_fields = Fields::from(vec![Field::new("byte_value", DataType::Int8, true)]);
981+
let to_fields = Fields::from(vec![Field::new(
982+
"renamed_nested",
983+
DataType::Struct(to_nested_fields),
984+
true,
985+
)]);
986+
let result = spark_cast(
987+
ColumnarValue::Array(outer),
988+
&DataType::Struct(to_fields),
989+
&SparkCastOptions::new(EvalMode::Ansi, "UTC", false),
990+
);
991+
992+
assert!(result.is_err());
993+
}
994+
964995
#[test]
965996
fn test_cast_struct_to_struct_drop_column() {
966997
let a: ArrayRef = Arc::new(Int32Array::from(vec![

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ package org.apache.comet
2222
import java.io.File
2323

2424
import scala.collection.mutable.ListBuffer
25+
import scala.jdk.CollectionConverters._
2526
import scala.util.Random
2627

2728
import org.apache.hadoop.fs.Path
@@ -1465,6 +1466,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14651466
}
14661467
}
14671468
}
1469+
1470+
val nestedType =
1471+
StructType(Seq(StructField("long_value", LongType), StructField("bool_value", BooleanType)))
1472+
val structType = StructType(
1473+
Seq(
1474+
StructField("int_value", IntegerType),
1475+
StructField("string_value", StringType),
1476+
StructField("nested_value", nestedType)))
1477+
val schema = StructType(Seq(StructField("a", structType)))
1478+
val rows = Seq(
1479+
Row(Row(1, "one", Row(10L, true))),
1480+
Row(Row(null, "missing-int", Row(-2L, false))),
1481+
Row(Row(3, null, null)),
1482+
Row(null))
1483+
1484+
castTest(spark.createDataFrame(rows.asJava, schema), StringType)
14681485
}
14691486

14701487
test("cast StructType to StructType") {
@@ -1479,6 +1496,44 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
14791496
}
14801497
}
14811498
}
1499+
1500+
val fromNestedType = StructType(Seq(StructField("inner_int", IntegerType)))
1501+
val fromType = StructType(
1502+
Seq(
1503+
StructField("long_value", LongType),
1504+
StructField("string_value", StringType),
1505+
StructField("nested_value", fromNestedType)))
1506+
val toNestedType = StructType(Seq(StructField("renamed_inner_long", LongType)))
1507+
val toType = StructType(
1508+
Seq(
1509+
StructField("renamed_byte", ByteType),
1510+
StructField("renamed_string", StringType),
1511+
StructField("renamed_nested", toNestedType)))
1512+
val schema = StructType(Seq(StructField("a", fromType)))
1513+
val rows = Seq(
1514+
Row(Row(1L, "one", Row(10))),
1515+
Row(Row(127L, null, Row(-20))),
1516+
Row(Row(null, "missing-long", null)),
1517+
Row(null))
1518+
1519+
castTest(spark.createDataFrame(rows.asJava, schema), toType)
1520+
1521+
val overflowFromType = StructType(
1522+
Seq(StructField("long_value", LongType), StructField("string_value", StringType)))
1523+
val overflowToType = StructType(
1524+
Seq(StructField("renamed_byte", ByteType), StructField("renamed_string", StringType)))
1525+
val overflowSchema = StructType(Seq(StructField("a", overflowFromType)))
1526+
val overflowRows = Seq(
1527+
Row(Row(1L, "fits")),
1528+
Row(Row(128L, "too-large")),
1529+
Row(Row(-129L, "too-small")),
1530+
Row(Row(null, "missing-long")),
1531+
Row(null))
1532+
1533+
castTest(
1534+
spark.createDataFrame(overflowRows.asJava, overflowSchema),
1535+
overflowToType,
1536+
expectAnsiFailure = true)
14821537
}
14831538

14841539
test("cast StructType to StructType with different names") {
@@ -1564,8 +1619,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
15641619
}
15651620

15661621
test("cast ArrayType to StringType - float double binary edge cases") {
1567-
import scala.jdk.CollectionConverters._
1568-
15691622
def bytes(values: Int*): Array[Byte] = values.map(_.toByte).toArray
15701623

15711624
def arrayInput(elementType: DataType, values: Seq[Any]): DataFrame = {
@@ -1630,6 +1683,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
16301683
DataTypes.TimestampNTZType,
16311684
BinaryType)
16321685
testArrayCastMatrix(types, ArrayType(_), generateArrays(100, _))
1686+
1687+
val schema = StructType(Seq(StructField("a", ArrayType(LongType))))
1688+
val rows = Seq(
1689+
Row(Seq[Any](1L, 127L, null)),
1690+
Row(Seq[Any](128L)),
1691+
Row(Seq[Any](-129L, 0L)),
1692+
Row(Seq.empty[Any]),
1693+
Row(null))
1694+
1695+
castTest(
1696+
spark.createDataFrame(rows.asJava, schema),
1697+
ArrayType(ByteType),
1698+
expectAnsiFailure = true)
16331699
}
16341700

16351701
test("cast MapType to MapType") {
@@ -1837,7 +1903,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18371903
}
18381904

18391905
private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
1840-
import scala.jdk.CollectionConverters._
18411906
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
18421907
def buildRows(values: Seq[Any]): Seq[Row] = {
18431908
Range(0, rowNum).map { i =>
@@ -1899,7 +1964,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
18991964
}
19001965

19011966
private def generateNestedArrays(rowNum: Int, elementType: DataType): DataFrame = {
1902-
import scala.jdk.CollectionConverters._
19031967
val schema = StructType(Seq(StructField("a", ArrayType(ArrayType(elementType)), true)))
19041968
val innerArrays = generateArrays(rowNum, elementType)
19051969
.collect()
@@ -2214,6 +2278,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22142278
hasIncompatibleType: Boolean = false,
22152279
testAnsi: Boolean = true,
22162280
testTry: Boolean = true,
2281+
expectAnsiFailure: Boolean = false,
22172282
useDataFrameDiff: Boolean = false): Unit = {
22182283

22192284
withTempPath { dir =>
@@ -2261,11 +2326,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22612326
.select(col("__row_id"), col("a"), col("a").cast(toType).as("converted"))
22622327
.orderBy(col("__row_id"))
22632328
.drop("__row_id")
2329+
if (expectAnsiFailure) {
2330+
assert(!hasIncompatibleType, "Expected ANSI failures must use Comet native execution")
2331+
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
2332+
}
22642333
val res = if (useDataFrameDiff) {
22652334
assertDataFrameEqualsWithExceptions(df, assertCometNative = !hasIncompatibleType)
22662335
} else {
22672336
checkSparkAnswerMaybeThrows(df)
22682337
}
2338+
if (expectAnsiFailure) {
2339+
assert(res._1.isDefined, "Expected Spark ANSI cast to fail")
2340+
assert(res._2.isDefined, "Expected Comet ANSI cast to fail")
2341+
}
22692342
res match {
22702343
case (None, None) =>
22712344
// neither system threw an exception

0 commit comments

Comments
 (0)