Skip to content

Commit d89e50a

Browse files
authored
feat: Support date to timestamp cast (#3383)
1 parent 58cf6e1 commit d89e50a

3 files changed

Lines changed: 189 additions & 5 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::{EvalMode, SparkError, SparkResult};
2121
use arrow::array::builder::StringBuilder;
2222
use arrow::array::{
2323
BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
24-
PrimitiveBuilder, StringArray, StructArray,
24+
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
2525
};
2626
use arrow::compute::can_cast_types;
2727
use arrow::datatypes::{
@@ -1100,6 +1100,7 @@ fn cast_array(
11001100
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
11011101
}
11021102
(Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
1103+
(Date32, Timestamp(_, tz)) => Ok(cast_date_to_timestamp(&array, cast_options, tz)?),
11031104
_ if cast_options.is_adapting_schema
11041105
|| is_datafusion_spark_compatible(from_type, to_type) =>
11051106
{
@@ -1118,6 +1119,50 @@ fn cast_array(
11181119
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
11191120
}
11201121

1122+
fn cast_date_to_timestamp(
1123+
array_ref: &ArrayRef,
1124+
cast_options: &SparkCastOptions,
1125+
target_tz: &Option<Arc<str>>,
1126+
) -> SparkResult<ArrayRef> {
1127+
let tz_str = if cast_options.timezone.is_empty() {
1128+
"UTC"
1129+
} else {
1130+
cast_options.timezone.as_str()
1131+
};
1132+
// safe to unwrap since we are falling back to UTC above
1133+
let tz = timezone::Tz::from_str(tz_str)?;
1134+
let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
1135+
let date_array = array_ref.as_primitive::<Date32Type>();
1136+
1137+
let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());
1138+
1139+
for date in date_array.iter() {
1140+
match date {
1141+
Some(date) => {
1142+
// safe to unwrap since chrono's range ( 262,143 yrs) is higher than
1143+
// number of years possible with days as i32 (~ 6 mil yrs)
1144+
// convert date in session timezone to timestamp in UTC
1145+
let naive_date = epoch + chrono::Duration::days(date as i64);
1146+
let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
1147+
let local_midnight_in_microsec = tz
1148+
.from_local_datetime(&local_midnight)
1149+
// return earliest possible time (edge case with spring / fall DST changes)
1150+
.earliest()
1151+
.map(|dt| dt.timestamp_micros())
1152+
// in case there is an issue with DST and returns None , we fall back to UTC
1153+
.unwrap_or((date as i64) * 86_400 * 1_000_000);
1154+
builder.append_value(local_midnight_in_microsec);
1155+
}
1156+
None => {
1157+
builder.append_null();
1158+
}
1159+
}
1160+
}
1161+
Ok(Arc::new(
1162+
builder.finish().with_timezone_opt(target_tz.clone()),
1163+
))
1164+
}
1165+
11211166
fn cast_string_to_float(
11221167
array: &ArrayRef,
11231168
to_type: &DataType,
@@ -3408,6 +3453,64 @@ mod tests {
34083453
assert!(result.is_err())
34093454
}
34103455

3456+
#[test]
fn test_cast_date_to_timestamp() {
    use arrow::array::Date32Array;

    // Midnight UTC, in microseconds since the epoch, for the two non-epoch inputs:
    // day 19723 (outside US daylight saving) and day 19793 (inside it).
    let non_dst_date = 1704067200000000i64;
    let dst_date = 1710115200000000i64;
    let seven_hours_ts = 25200000000i64;
    let eight_hours_ts = 28800000000i64;

    // Input: the epoch, the two dates above and a null. Broader coverage lives in
    // the Spark-side test suite.
    let dates: ArrayRef = Arc::new(Date32Array::from(vec![
        Some(0),
        Some(19723),
        Some(19793),
        None,
    ]));

    // (session timezone, expected microsecond values for the three non-null rows)
    let cases = [
        // UTC: no offset at all
        ("UTC", [0, non_dst_date, dst_date]),
        // Los Angeles observes daylight saving: 8h offset in winter, 7h in summer
        (
            "America/Los_Angeles",
            [
                eight_hours_ts,
                non_dst_date + eight_hours_ts,
                dst_date + seven_hours_ts,
            ],
        ),
        // Phoenix does not observe daylight saving: 7h offset all year
        (
            "America/Phoenix",
            [
                seven_hours_ts,
                non_dst_date + seven_hours_ts,
                dst_date + seven_hours_ts,
            ],
        ),
    ];

    for (session_tz, expected) in cases.iter() {
        let result = cast_array(
            Arc::clone(&dates),
            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
            &SparkCastOptions::new(EvalMode::Legacy, *session_tz, false),
        )
        .unwrap();
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(ts.value(idx), *want);
        }
        // the trailing null input must stay null
        assert!(ts.is_null(3));
    }
}
3513+
34113514
#[test]
34123515
fn test_cast_struct_to_utf8() {
34133516
let a: ArrayRef = Arc::new(Int32Array::from(vec![

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
168168
}
169169
}
170170
Compatible()
171+
case (DataTypes.DateType, toType) => canCastFromDate(toType)
171172
case _ => unsupported(fromType, toType)
172173
}
173174
}
@@ -344,6 +345,12 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
344345
case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not supported"))
345346
}
346347

348+
/**
 * Reports whether a cast from DateType to `toType` is supported. Only TimestampType is
 * accepted; every other target type is reported as unsupported with an explanatory message.
 */
private def canCastFromDate(toType: DataType): SupportLevel =
  if (toType == DataTypes.TimestampType) {
    Compatible()
  } else {
    Unsupported(Some(s"Cast from DateType to $toType is not supported"))
  }
353+
347354
/** Builds the `Unsupported` result used for a cast between the given pair of types. */
private def unsupported(fromType: DataType, toType: DataType): Unsupported = {
  val reason = s"Cast from $fromType to $toType is not supported"
  Unsupported(Some(reason))
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -989,9 +989,27 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
989989
castTest(generateDates(), DataTypes.StringType)
990990
}
991991

992-
ignore("cast DateType to TimestampType") {
993-
// Arrow error: Cast error: Casting from Date32 to Timestamp(Microsecond, Some("UTC")) not supported
994-
castTest(generateDates(), DataTypes.TimestampType)
992+
test("cast DateType to TimestampType") {
  // Session timezones under which the date -> timestamp cast is exercised.
  val compatibleTimezones = Seq(
    "UTC",
    "America/New_York",
    "America/Chicago",
    "America/Denver",
    "America/Los_Angeles",
    "Europe/London",
    "Europe/Paris",
    "Europe/Berlin",
    "Asia/Tokyo",
    "Asia/Shanghai",
    "Asia/Singapore",
    "Asia/Kolkata",
    "Australia/Sydney",
    "Pacific/Auckland")
  // foreach rather than map: each iteration runs only for its side effects, so there
  // is no reason to build and discard a collection of results.
  compatibleTimezones.foreach { tz =>
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
      castTest(generateDates(), DataTypes.TimestampType)
    }
  }
}
9961014

9971015
// CAST from TimestampType
@@ -1264,7 +1282,63 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12641282
}
12651283

12661284
/**
 * Builds a single-column DataFrame (column "a") of dates, obtained by casting date strings
 * to DateType. The strings are a regular monthly sample, days around DST transitions and a
 * few edge cases, de-duplicated and passed through `withNulls`.
 */
private def generateDates(): DataFrame = {
  // 1st, 10th and 20th of every month from the epoch year through 2027
  val sampledDates = for {
    year <- 1970 to 2027
    month <- 1 to 12
    day <- Seq(1, 10, 20)
  } yield f"$year-$month%02d-$day%02d"

  // Days around DST transitions for US, EU and Australia, listed per month in the
  // order in which they are emitted for each year.
  val dstDaysByMonth: Seq[(Int, Seq[Int])] = Seq(
    // March: spring forward
    3 -> (Seq(8, 9, 10, 11, 14, 15) ++ (25 to 31)),
    // April: Australia fall back
    4 -> (1 to 5),
    // October: EU fall back and Australia spring forward
    10 -> ((1 to 5) ++ (25 to 31)),
    // November: US fall back
    11 -> (1 to 8))

  // DST transition dates (1970-2099)
  val dstDates = for {
    year <- 1970 to 2099
    (month, days) <- dstDaysByMonth
    day <- days
  } yield f"$year-$month%02d-$day%02d"

  // Edge cases
  val edgeCases = Seq("1969-12-31", "2000-02-29", "999-01-01", "12345-01-01")
  val values = (sampledDates ++ dstDates ++ edgeCases).distinct
  withNulls(values).toDF("b").withColumn("a", col("b").cast(DataTypes.DateType)).drop("b")
}
12701344

0 commit comments

Comments
 (0)