Skip to content

Commit 30dbe74

Browse files
authored
fix: array to array cast (#2897)
1 parent b71b53b commit 30dbe74

File tree

3 files changed

+233
-30
lines changed

3 files changed

+233
-30
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
4141
use crate::{EvalMode, SparkError};
4242
use arrow::array::builder::StringBuilder;
4343
use arrow::array::{
44-
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
44+
new_null_array, BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray,
45+
StringArray, StructArray,
4546
};
46-
use arrow::compute::can_cast_types;
4747
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4848
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4949
use arrow::error::ArrowError;
@@ -311,6 +311,9 @@ pub(crate) fn cast_array(
311311
};
312312

313313
let cast_result = match (&from_type, to_type) {
314+
// Null arrays carry no concrete values, so Arrow's native cast can change only the
315+
// logical type while preserving length and nullness.
316+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
314317
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
315318
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
316319
(Utf8, Timestamp(_, _)) => cast_string_to_timestamp(
@@ -387,8 +390,25 @@ pub(crate) fn cast_array(
387390
cast_options,
388391
)?),
389392
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
390-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
391-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
393+
(List(_), List(to)) => {
394+
// Cast list elements recursively so nested array casts follow Spark semantics
395+
// instead of relying on Arrow's top-level cast support.
396+
let list_array = array.as_list::<i32>();
397+
let casted_values = match (list_array.values().data_type(), to.data_type()) {
398+
// Spark legacy array casts produce null elements for array<Date> -> array<Int>.
399+
(Date32, Int32) => new_null_array(to.data_type(), list_array.values().len()),
400+
_ => cast_array(
401+
Arc::clone(list_array.values()),
402+
to.data_type(),
403+
cast_options,
404+
)?,
405+
};
406+
Ok(Arc::new(ListArray::new(
407+
Arc::clone(to),
408+
list_array.offsets().clone(),
409+
casted_values,
410+
list_array.nulls().cloned(),
411+
)) as ArrayRef)
392412
}
393413
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
394414
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -824,7 +844,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
824844
#[cfg(test)]
825845
mod tests {
826846
use super::*;
827-
use arrow::array::StringArray;
847+
use arrow::array::{ListArray, NullArray, StringArray};
848+
use arrow::buffer::OffsetBuffer;
828849
use arrow::datatypes::TimestampMicrosecondType;
829850
use arrow::datatypes::{Field, Fields};
830851
#[test]
@@ -950,8 +971,6 @@ mod tests {
950971

951972
#[test]
952973
fn test_cast_string_array_to_string() {
953-
use arrow::array::ListArray;
954-
use arrow::buffer::OffsetBuffer;
955974
let values_array =
956975
StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
957976
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
@@ -976,8 +995,6 @@ mod tests {
976995

977996
#[test]
978997
fn test_cast_i32_array_to_string() {
979-
use arrow::array::ListArray;
980-
use arrow::buffer::OffsetBuffer;
981998
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
982999
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
9831000
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -998,4 +1015,33 @@ mod tests {
9981015
assert_eq!(r#"[null]"#, string_array.value(2));
9991016
assert_eq!(r#"[]"#, string_array.value(3));
10001017
}
1018+
1019+
#[test]
fn test_cast_array_of_nulls_to_array() {
    // Source: a list<null> array with rows of length 2, 1 and 0 over a 3-element
    // null child array, and no list-level validity buffer.
    let null_item = Arc::new(Field::new("item", DataType::Null, true));
    let source_offsets = OffsetBuffer::<i32>::new(vec![0, 2, 3, 3].into());
    let source: ArrayRef = Arc::new(ListArray::new(
        null_item,
        source_offsets,
        Arc::new(NullArray::new(3)),
        None,
    ));

    // Target: list<int32> with nullable items, cast in legacy mode.
    let int_item = Arc::new(Field::new("item", DataType::Int32, true));
    let target_type = DataType::List(int_item);
    let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    let casted = cast_array(source, &target_type, &options).unwrap();

    // The list layout (row count and offsets) must be carried over unchanged.
    let casted_list = casted.as_list::<i32>();
    assert_eq!(3, casted_list.len());
    assert_eq!(casted_list.value_offsets(), &[0, 2, 3, 3]);

    // The child array must now be Int32 (`as_primitive` panics otherwise) and
    // every one of its values must still be null.
    let children = casted_list.values().as_primitive::<Int32Type>();
    assert_eq!(3, children.len());
    assert_eq!(3, children.null_count());
    assert!(children.iter().all(|child| child.is_none()));
}
10011047
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
142142

143143
(fromType, toType) match {
144144
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
145+
case (ArrayType(DataTypes.DateType, _), ArrayType(toElementType, _))
146+
if toElementType != DataTypes.IntegerType && toElementType != DataTypes.StringType =>
147+
unsupported(fromType, toType)
145148
case (dt: ArrayType, DataTypes.StringType) if dt.elementType == DataTypes.BinaryType =>
146149
Incompatible()
147150
case (dt: ArrayType, DataTypes.StringType) =>

0 commit comments

Comments
 (0)