diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index fd0a211b29..63e1c04762 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -92,6 +92,10 @@ harness = false
 name = "to_csv"
 harness = false
 
+[[bench]]
+name = "cast_int_to_timestamp"
+harness = false
+
 [[test]]
 name = "test_udf_registration"
 path = "tests/spark_expr_reg.rs"
diff --git a/native/spark-expr/benches/cast_int_to_timestamp.rs b/native/spark-expr/benches/cast_int_to_timestamp.rs
new file mode 100644
index 0000000000..20143d2b0e
--- /dev/null
+++ b/native/spark-expr/benches/cast_int_to_timestamp.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::{Int16Builder, Int32Builder, Int64Builder, Int8Builder};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    // Test with UTC timezone
+    let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+    let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
+
+    let mut group = c.benchmark_group("cast_int_to_timestamp");
+
+    // Int8 -> Timestamp
+    let batch_i8 = create_int8_batch();
+    let expr_i8 = Arc::new(Column::new("a", 0));
+    let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i8_to_timestamp", |b| {
+        b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
+    });
+
+    // Int16 -> Timestamp
+    let batch_i16 = create_int16_batch();
+    let expr_i16 = Arc::new(Column::new("a", 0));
+    let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i16_to_timestamp", |b| {
+        b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
+    });
+
+    // Int32 -> Timestamp
+    let batch_i32 = create_int32_batch();
+    let expr_i32 = Arc::new(Column::new("a", 0));
+    let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i32_to_timestamp", |b| {
+        b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
+    });
+
+    // Int64 -> Timestamp
+    let batch_i64 = create_int64_batch();
+    let expr_i64 = Arc::new(Column::new("a", 0));
+    let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i64_to_timestamp", |b| {
+        b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
+    });
+
+    group.finish();
+}
+
+fn create_int8_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8, true)]));
+    let mut b = Int8Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i8>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int16_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int16, true)]));
+    let mut b = Int16Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i16>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int32_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+    let mut b = Int32Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i32>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int64_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
+    let mut b = Int64Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i64>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 2809104f26..f5ab83b8a5 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -613,6 +613,25 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// Appends each element of `$array` (read as `$primitive_type`) to `$builder`, converting
+// non-null values from seconds to microseconds and keeping nulls as nulls.
+macro_rules! cast_int_to_timestamp_impl {
+    ($array:expr, $builder:expr, $primitive_type:ty) => {{
+        let arr = $array.as_primitive::<$primitive_type>();
+        for i in 0..arr.len() {
+            if arr.is_null(i) {
+                $builder.append_null();
+            } else {
+                // saturating_mul limits to i64::MIN/MAX on overflow instead of panicking,
+                // which could occur when converting extreme values (e.g., Long.MIN_VALUE)
+                // matching spark behavior (irrespective of EvalMode)
+                let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND);
+                $builder.append_value(micros);
+            }
+        }
+    }};
+}
+
 // copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
 fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
     let (sign, rest) = match value_str.strip_prefix('-') {
@@ -915,6 +934,7 @@ fn cast_array(
         (Boolean, Decimal128(precision, scale)) => {
             cast_boolean_to_decimal(&array, *precision, *scale)
         }
+        (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
         _ if cast_options.is_adapting_schema
             || is_datafusion_spark_compatible(from_type, to_type) =>
         {
@@ -933,6 +953,33 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+/// Casts an integer array (Int8/Int16/Int32/Int64) of epoch seconds to a timestamp array.
+///
+/// The result is always built with microsecond precision and tagged with `target_tz`; nulls
+/// are preserved. Any other input type is reported as `SparkError::Internal`.
+fn cast_int_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+    // Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds.
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+    match array_ref.data_type() {
+        DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
+        DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
+        DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
+        DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
+        dt => {
+            return Err(SparkError::Internal(format!(
+                "Unsupported type for cast_int_to_timestamp: {:?}",
+                dt
+            )))
+        }
+    }
+
+    Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
+}
+
 fn cast_date_to_timestamp(
     array_ref: &ArrayRef,
     cast_options: &SparkCastOptions,
@@ -3519,4 +3566,94 @@ mod tests {
         assert_eq!(r#"[null]"#, string_array.value(2));
         assert_eq!(r#"[]"#, string_array.value(3));
     }
+
+    #[test]
+    fn test_cast_int_to_timestamp() {
+        let timezones: [Option<Arc<str>>; 6] = [
+            Some(Arc::from("UTC")),
+            Some(Arc::from("America/New_York")),
+            Some(Arc::from("America/Los_Angeles")),
+            Some(Arc::from("Europe/London")),
+            Some(Arc::from("Asia/Tokyo")),
+            Some(Arc::from("Australia/Sydney")),
+        ];
+
+        for tz in &timezones {
+            let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(127),
+                Some(-128),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int8_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 127_000_000);
+            assert_eq!(ts_array.value(4), -128_000_000);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(32767),
+                Some(-32768),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int16_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 32_767_000_000_i64);
+            assert_eq!(ts_array.value(4), -32_768_000_000_i64);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(1704067200),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int32_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64);
+            assert!(ts_array.is_null(4));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(i64::MAX),
+                Some(i64::MIN),
+            ]));
+
+            let result = cast_int_to_timestamp(&int64_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000_i64);
+            assert_eq!(ts_array.value(2), -1_000_000_i64);
+            assert_eq!(ts_array.value(3), i64::MAX);
+            assert_eq!(ts_array.value(4), i64::MIN);
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+        }
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 8cbe76a19d..15dfcb2d7c 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -299,6 +299,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.ByteType, toType)
     }
@@ -313,6 +315,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.ShortType, toType)
     }
@@ -328,6 +332,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case _: DecimalType => Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.IntegerType, toType)
     }
@@ -343,6 +349,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case _: DecimalType => Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.LongType, toType)
     }
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 326904d564..72c2390d71 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -65,6 +65,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   lazy val usingParquetExecWithIncompatTypes: Boolean =
     hasUnsignedSmallIntSafetyCheck(conf)
 
+  // Timezone list to check temporal type casts
+  private val compatibleTimezones = Seq(
+    "UTC",
+    "America/New_York",
+    "America/Chicago",
+    "America/Denver",
+    "America/Los_Angeles",
+    "Europe/London",
+    "Europe/Paris",
+    "Europe/Berlin",
+    "Asia/Tokyo",
+    "Asia/Shanghai",
+    "Asia/Singapore",
+    "Asia/Kolkata",
+    "Australia/Sydney",
+    "Pacific/Auckland")
+
   test("all valid cast combinations covered") {
     val names = testNames
 
@@ -223,12 +240,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       testTry = false)
   }
 
-  ignore("cast ByteType to TimestampType") {
-    // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
-    castTest(
-      generateBytes(),
-      DataTypes.TimestampType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+  test("cast ByteType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(
+          generateBytes(),
+          DataTypes.TimestampType,
+          hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      }
+    }
   }
 
   // CAST from ShortType
@@ -300,12 +320,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       testTry = false)
   }
 
-  ignore("cast ShortType to TimestampType") {
-    // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
-    castTest(
-      generateShorts(),
-      DataTypes.TimestampType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+  test("cast ShortType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(
+          generateShorts(),
+          DataTypes.TimestampType,
+          hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      }
+    }
   }
 
   // CAST from integer
@@ -363,9 +386,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
-  ignore("cast IntegerType to TimestampType") {
-    // input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
-    castTest(generateInts(), DataTypes.TimestampType)
+  test("cast IntegerType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(generateInts(), DataTypes.TimestampType)
+      }
+    }
   }
 
   // CAST from LongType
@@ -410,9 +436,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
-  ignore("cast LongType to TimestampType") {
-    // java.lang.ArithmeticException: long overflow
-    castTest(generateLongs(), DataTypes.TimestampType)
+  test("cast LongType to TimestampType") {
+    // Cast back to long avoids java.sql.Timestamp overflow during collect() for extreme values
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        withTable("t1") {
+          generateLongs().write.saveAsTable("t1")
+          val df = spark.sql("select a, cast(cast(a as timestamp) as long) from t1")
+          checkSparkAnswerAndOperator(df)
+        }
+      }
+    }
   }
 
   // CAST from FloatType
@@ -1042,13 +1076,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   ignore("cast TimestampType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/352
-    // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null]
+    // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null
     castTest(generateTimestamps(), DataTypes.ShortType)
   }
 
   ignore("cast TimestampType to IntegerType") {
     // https://github.com/apache/datafusion-comet/issues/352
-    // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null]
+    // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null
     castTest(generateTimestamps(), DataTypes.IntegerType)
   }
 