Skip to content

Commit 668f7a3

Browse files
authored
Merge branch 'main' into cast_module_refactor_boolean
2 parents 520ba4d + 219859b commit 668f7a3

5 files changed

Lines changed: 328 additions & 20 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ harness = false
9292
name = "to_csv"
9393
harness = false
9494

95+
[[bench]]
96+
name = "cast_int_to_timestamp"
97+
harness = false
98+
9599
[[test]]
96100
name = "test_udf_registration"
97101
path = "tests/spark_expr_reg.rs"
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::builder::{Int16Builder, Int32Builder, Int64Builder, Int8Builder};
19+
use arrow::array::RecordBatch;
20+
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
21+
use criterion::{criterion_group, criterion_main, Criterion};
22+
use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
23+
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
24+
use std::sync::Arc;
25+
26+
const BATCH_SIZE: usize = 8192;
27+
28+
fn criterion_benchmark(c: &mut Criterion) {
29+
// Test with UTC timezone
30+
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
31+
let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
32+
33+
let mut group = c.benchmark_group("cast_int_to_timestamp");
34+
35+
// Int8 -> Timestamp
36+
let batch_i8 = create_int8_batch();
37+
let expr_i8 = Arc::new(Column::new("a", 0));
38+
let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
39+
group.bench_function("cast_i8_to_timestamp", |b| {
40+
b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
41+
});
42+
43+
// Int16 -> Timestamp
44+
let batch_i16 = create_int16_batch();
45+
let expr_i16 = Arc::new(Column::new("a", 0));
46+
let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
47+
group.bench_function("cast_i16_to_timestamp", |b| {
48+
b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
49+
});
50+
51+
// Int32 -> Timestamp
52+
let batch_i32 = create_int32_batch();
53+
let expr_i32 = Arc::new(Column::new("a", 0));
54+
let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
55+
group.bench_function("cast_i32_to_timestamp", |b| {
56+
b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
57+
});
58+
59+
// Int64 -> Timestamp
60+
let batch_i64 = create_int64_batch();
61+
let expr_i64 = Arc::new(Column::new("a", 0));
62+
let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
63+
group.bench_function("cast_i64_to_timestamp", |b| {
64+
b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
65+
});
66+
67+
group.finish();
68+
}
69+
70+
fn create_int8_batch() -> RecordBatch {
71+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8, true)]));
72+
let mut b = Int8Builder::with_capacity(BATCH_SIZE);
73+
for i in 0..BATCH_SIZE {
74+
if i % 10 == 0 {
75+
b.append_null();
76+
} else {
77+
b.append_value(rand::random::<i8>());
78+
}
79+
}
80+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
81+
}
82+
83+
fn create_int16_batch() -> RecordBatch {
84+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int16, true)]));
85+
let mut b = Int16Builder::with_capacity(BATCH_SIZE);
86+
for i in 0..BATCH_SIZE {
87+
if i % 10 == 0 {
88+
b.append_null();
89+
} else {
90+
b.append_value(rand::random::<i16>());
91+
}
92+
}
93+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
94+
}
95+
96+
fn create_int32_batch() -> RecordBatch {
97+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
98+
let mut b = Int32Builder::with_capacity(BATCH_SIZE);
99+
for i in 0..BATCH_SIZE {
100+
if i % 10 == 0 {
101+
b.append_null();
102+
} else {
103+
b.append_value(rand::random::<i32>());
104+
}
105+
}
106+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
107+
}
108+
109+
fn create_int64_batch() -> RecordBatch {
110+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
111+
let mut b = Int64Builder::with_capacity(BATCH_SIZE);
112+
for i in 0..BATCH_SIZE {
113+
if i % 10 == 0 {
114+
b.append_null();
115+
} else {
116+
b.append_value(rand::random::<i64>());
117+
}
118+
}
119+
RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
120+
}
121+
122+
fn config() -> Criterion {
123+
Criterion::default()
124+
}
125+
126+
criterion_group! {
127+
name = benches;
128+
config = config();
129+
targets = criterion_benchmark
130+
}
131+
criterion_main!(benches);

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,23 @@ macro_rules! cast_decimal_to_int32_up {
610610
}};
611611
}
612612

613+
/// Appends every element of `$array` (viewed as a primitive array of
/// `$primitive_type`) to `$builder` as a microsecond value.
///
/// Null elements are appended as nulls. Non-null elements are widened to
/// `i64` and multiplied by `MICROS_PER_SECOND`.
macro_rules! cast_int_to_timestamp_impl {
    ($array:expr, $builder:expr, $primitive_type:ty) => {{
        let seconds = $array.as_primitive::<$primitive_type>();
        for maybe_secs in seconds.iter() {
            match maybe_secs {
                None => $builder.append_null(),
                Some(secs) => {
                    // saturating_mul clamps to i64::MIN/MAX on overflow instead of
                    // panicking, which could occur when converting extreme values
                    // (e.g., Long.MIN_VALUE), matching spark behavior
                    // (irrespective of EvalMode)
                    $builder.append_value((secs as i64).saturating_mul(MICROS_PER_SECOND))
                }
            }
        }
    }};
}
629+
613630
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
614631
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
615632
let (sign, rest) = match value_str.strip_prefix('-') {
@@ -912,6 +929,7 @@ pub(crate) fn cast_array(
912929
(Boolean, Decimal128(precision, scale)) => {
913930
cast_boolean_to_decimal(&array, *precision, *scale)
914931
}
932+
(Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
915933
_ if cast_options.is_adapting_schema
916934
|| is_datafusion_spark_compatible(from_type, to_type) =>
917935
{
@@ -930,6 +948,29 @@ pub(crate) fn cast_array(
930948
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
931949
}
932950

951+
fn cast_int_to_timestamp(
952+
array_ref: &ArrayRef,
953+
target_tz: &Option<Arc<str>>,
954+
) -> SparkResult<ArrayRef> {
955+
// Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds.
956+
let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
957+
958+
match array_ref.data_type() {
959+
DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
960+
DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
961+
DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
962+
DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
963+
dt => {
964+
return Err(SparkError::Internal(format!(
965+
"Unsupported type for cast_int_to_timestamp: {:?}",
966+
dt
967+
)))
968+
}
969+
}
970+
971+
Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
972+
}
973+
933974
fn cast_date_to_timestamp(
934975
array_ref: &ArrayRef,
935976
cast_options: &SparkCastOptions,
@@ -3399,4 +3440,94 @@ mod tests {
33993440
assert_eq!(r#"[null]"#, string_array.value(2));
34003441
assert_eq!(r#"[]"#, string_array.value(3));
34013442
}
3443+
3444+
/// Verifies `cast_int_to_timestamp` for all four integer widths across
/// several target timezones, including the timezone-less (`None`) case.
///
/// Each width covers zero, ±1, its extreme values and a null row. Int64
/// extremes must saturate to `i64::MAX` / `i64::MIN` rather than overflow.
#[test]
fn test_cast_int_to_timestamp() {
    /// Casts `input` with `tz`, then checks the length, every row (value or
    /// null) and the timezone carried by the result.
    fn assert_cast(input: ArrayRef, tz: &Option<Arc<str>>, expected: &[Option<i64>]) {
        let result = cast_int_to_timestamp(&input, tz).unwrap();
        let ts_array = result.as_primitive::<TimestampMicrosecondType>();

        assert_eq!(ts_array.len(), expected.len());
        for (row, want) in expected.iter().enumerate() {
            match want {
                Some(micros) => {
                    assert!(!ts_array.is_null(row), "row {} should not be null", row);
                    assert_eq!(ts_array.value(row), *micros, "row {}", row);
                }
                None => assert!(ts_array.is_null(row), "row {} should be null", row),
            }
        }
        assert_eq!(ts_array.timezone(), tz.as_deref());
    }

    let timezones: [Option<Arc<str>>; 7] = [
        None,
        Some(Arc::from("UTC")),
        Some(Arc::from("America/New_York")),
        Some(Arc::from("America/Los_Angeles")),
        Some(Arc::from("Europe/London")),
        Some(Arc::from("Asia/Tokyo")),
        Some(Arc::from("Australia/Sydney")),
    ];

    for tz in &timezones {
        assert_cast(
            Arc::new(Int8Array::from(vec![
                Some(0),
                Some(1),
                Some(-1),
                Some(127),
                Some(-128),
                None,
            ])),
            tz,
            &[
                Some(0),
                Some(1_000_000),
                Some(-1_000_000),
                Some(127_000_000),
                Some(-128_000_000),
                None,
            ],
        );

        assert_cast(
            Arc::new(Int16Array::from(vec![
                Some(0),
                Some(1),
                Some(-1),
                Some(32767),
                Some(-32768),
                None,
            ])),
            tz,
            &[
                Some(0),
                Some(1_000_000),
                Some(-1_000_000),
                Some(32_767_000_000_i64),
                Some(-32_768_000_000_i64),
                None,
            ],
        );

        assert_cast(
            Arc::new(Int32Array::from(vec![
                Some(0),
                Some(1),
                Some(-1),
                Some(1704067200),
                Some(i32::MAX),
                Some(i32::MIN),
                None,
            ])),
            tz,
            &[
                Some(0),
                Some(1_000_000),
                Some(-1_000_000),
                Some(1_704_067_200_000_000_i64),
                Some(2_147_483_647_000_000_i64),
                Some(-2_147_483_648_000_000_i64),
                None,
            ],
        );

        // i64 extremes overflow when scaled to microseconds and must saturate.
        assert_cast(
            Arc::new(Int64Array::from(vec![
                Some(0),
                Some(1),
                Some(-1),
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ])),
            tz,
            &[
                Some(0),
                Some(1_000_000_i64),
                Some(-1_000_000_i64),
                Some(i64::MAX),
                Some(i64::MIN),
                None,
            ],
        );
    }
}
34023533
}

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
299299
Compatible()
300300
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
301301
Compatible()
302+
case DataTypes.TimestampType =>
303+
Compatible()
302304
case _ =>
303305
unsupported(DataTypes.ByteType, toType)
304306
}
@@ -313,6 +315,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
313315
Compatible()
314316
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
315317
Compatible()
318+
case DataTypes.TimestampType =>
319+
Compatible()
316320
case _ =>
317321
unsupported(DataTypes.ShortType, toType)
318322
}
@@ -328,6 +332,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
328332
case _: DecimalType =>
329333
Compatible()
330334
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
335+
case DataTypes.TimestampType =>
336+
Compatible()
331337
case _ =>
332338
unsupported(DataTypes.IntegerType, toType)
333339
}
@@ -343,6 +349,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
343349
case _: DecimalType =>
344350
Compatible()
345351
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
352+
case DataTypes.TimestampType =>
353+
Compatible()
346354
case _ =>
347355
unsupported(DataTypes.LongType, toType)
348356
}

0 commit comments

Comments
 (0)