Skip to content

Commit f0455fd

Browse files
committed
fix: array to array cast
1 parent f6d84b1 commit f0455fd

2 files changed

Lines changed: 41 additions & 8 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 15 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@ use arrow::array::builder::StringBuilder;
4343
use arrow::array::{
4444
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
4545
};
46-
use arrow::compute::can_cast_types;
4746
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
4847
use arrow::datatypes::{Field, Fields, GenericBinaryType};
4948
use arrow::error::ArrowError;
@@ -294,6 +293,7 @@ pub(crate) fn cast_array(
294293
};
295294

296295
let cast_result = match (&from_type, to_type) {
296+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
297297
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
298298
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
299299
(Utf8, Timestamp(_, _)) => {
@@ -366,8 +366,18 @@ pub(crate) fn cast_array(
366366
cast_options,
367367
)?),
368368
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
369-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
370-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
369+
(List(_), List(to)) => {
370+
let list_array = array.as_list::<i32>();
371+
Ok(Arc::new(ListArray::new(
372+
Arc::clone(to),
373+
list_array.offsets().clone(),
374+
cast_array(
375+
Arc::clone(list_array.values()),
376+
to.data_type(),
377+
cast_options,
378+
)?,
379+
list_array.nulls().cloned(),
380+
)) as ArrayRef)
371381
}
372382
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
373383
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -803,7 +813,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
803813
#[cfg(test)]
804814
mod tests {
805815
use super::*;
806-
use arrow::array::StringArray;
816+
use arrow::array::{BooleanArray, Decimal128Array, ListArray, StringArray};
817+
use arrow::buffer::OffsetBuffer;
807818
use arrow::datatypes::TimestampMicrosecondType;
808819
use arrow::datatypes::{Field, Fields};
809820
#[test]
@@ -955,8 +966,6 @@ mod tests {
955966

956967
#[test]
957968
fn test_cast_i32_array_to_string() {
958-
use arrow::array::ListArray;
959-
use arrow::buffer::OffsetBuffer;
960969
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
961970
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
962971
let item_field = Arc::new(Field::new("item", DataType::Int32, true));

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1266,6 +1266,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12661266
}
12671267
}
12681268

1269+
test("cast ArrayType to ArrayType") {
1270+
val types = Seq(
1271+
BooleanType,
1272+
StringType,
1273+
ByteType,
1274+
IntegerType,
1275+
LongType,
1276+
ShortType,
1277+
DecimalType(10, 2),
1278+
DecimalType(38, 18))
1279+
for (fromType <- types) {
1280+
for (toType <- types) {
1281+
if (fromType != toType &&
1282+
!tags
1283+
.get(s"cast $fromType to $toType")
1284+
.exists(s => s.contains("org.scalatest.Ignore")) &&
1285+
Cast.canCast(fromType, toType) &&
1286+
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) {
1287+
castTest(generateArrays(100, fromType), ArrayType(toType))
1288+
}
1289+
}
1290+
}
1291+
}
1292+
12691293
private def generateFloats(): DataFrame = {
12701294
withNulls(gen.generateFloats(dataSize)).toDF("a")
12711295
}
@@ -1294,10 +1318,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12941318
withNulls(gen.generateLongs(dataSize)).toDF("a")
12951319
}
12961320

1297-
private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
1321+
private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
12981322
import scala.collection.JavaConverters._
12991323
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
1300-
spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
1324+
spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema)
13011325
}
13021326

13031327
// https://github.com/apache/datafusion-comet/issues/2038

0 commit comments

Comments
 (0)