Skip to content

Commit 97a40df

Browse files
committed
support_timestamp_to_int_type
1 parent d3ea9fd commit 97a40df

4 files changed

Lines changed: 137 additions & 21 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use arrow::array::{
2525
PrimitiveBuilder, StringArray, StructArray, TimestampMicrosecondBuilder,
2626
};
2727
use arrow::compute::can_cast_types;
28+
use arrow::datatypes::DataType::Int64;
2829
use arrow::datatypes::{
2930
i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
3031
Schema,
@@ -915,6 +916,9 @@ fn cast_array(
915916
(Boolean, Decimal128(precision, scale)) => {
916917
cast_boolean_to_decimal(&array, *precision, *scale)
917918
}
919+
(Int8 | Int16 | Int32 | Int64, Timestamp(_, _)) => {
920+
cast_int_to_timestamp(&array, cast_options)
921+
}
918922
_ if cast_options.is_adapting_schema
919923
|| is_datafusion_spark_compatible(from_type, to_type) =>
920924
{
@@ -933,6 +937,29 @@ fn cast_array(
933937
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
934938
}
935939

940+
fn cast_int_to_timestamp(
941+
array_ref: &ArrayRef,
942+
cast_options: &SparkCastOptions,
943+
) -> SparkResult<ArrayRef> {
944+
// Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds.
945+
let int64_array = cast_with_options(&array_ref, &Int64, &CAST_OPTIONS)?;
946+
let int64_arr = int64_array.as_primitive::<Int64Type>();
947+
948+
let mut builder = TimestampMicrosecondBuilder::with_capacity(int64_arr.len());
949+
for i in 0..int64_arr.len() {
950+
if int64_arr.is_null(i) {
951+
builder.append_null();
952+
} else {
953+
let micros = int64_arr.value(i).saturating_mul(MICROS_PER_SECOND);
954+
builder.append_value(micros);
955+
}
956+
}
957+
958+
// input tz is always defined or set to UTC on spark side
959+
let tz: Arc<str> = Arc::from(cast_options.timezone.as_str());
960+
Ok(Arc::new(builder.finish().with_timezone(tz)) as ArrayRef)
961+
}
962+
936963
fn cast_date_to_timestamp(
937964
array_ref: &ArrayRef,
938965
cast_options: &SparkCastOptions,

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
249249
private def canCastFromTimestamp(toType: DataType): SupportLevel = {
250250
toType match {
251251
case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
252-
DataTypes.IntegerType =>
252+
DataTypes.IntegerType =>
253253
// https://github.com/apache/datafusion-comet/issues/352
254254
// this seems like an edge case that isn't important for us to support
255255
unsupported(DataTypes.TimestampType, toType)
@@ -279,6 +279,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
279279
Compatible()
280280
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
281281
Compatible()
282+
case DataTypes.TimestampType =>
283+
Compatible()
282284
case _ =>
283285
unsupported(DataTypes.ByteType, toType)
284286
}
@@ -293,6 +295,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
293295
Compatible()
294296
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
295297
Compatible()
298+
case DataTypes.TimestampType =>
299+
Compatible()
296300
case _ =>
297301
unsupported(DataTypes.ShortType, toType)
298302
}
@@ -308,6 +312,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
308312
case _: DecimalType =>
309313
Compatible()
310314
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
315+
case DataTypes.TimestampType =>
316+
Compatible()
311317
case _ =>
312318
unsupported(DataTypes.IntegerType, toType)
313319
}
@@ -323,6 +329,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
323329
case _: DecimalType =>
324330
Compatible()
325331
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
332+
case DataTypes.TimestampType =>
333+
Compatible()
326334
case _ =>
327335
unsupported(DataTypes.LongType, toType)
328336
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
223223
testTry = false)
224224
}
225225

226-
ignore("cast ByteType to TimestampType") {
227-
// input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
228-
castTest(
229-
generateBytes(),
230-
DataTypes.TimestampType,
231-
hasIncompatibleType = usingParquetExecWithIncompatTypes)
226+
test("cast ByteType to TimestampType") {
227+
val compatibleTimezones = Seq(
228+
"UTC",
229+
"America/New_York",
230+
"America/Los_Angeles",
231+
"Europe/London",
232+
"Asia/Tokyo",
233+
"Australia/Sydney")
234+
compatibleTimezones.foreach { tz =>
235+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
236+
castTest(
237+
generateBytes(),
238+
DataTypes.TimestampType,
239+
hasIncompatibleType = usingParquetExecWithIncompatTypes)
240+
}
241+
}
232242
}
233243

234244
// CAST from ShortType
@@ -300,12 +310,22 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
300310
testTry = false)
301311
}
302312

303-
ignore("cast ShortType to TimestampType") {
304-
// input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
305-
castTest(
306-
generateShorts(),
307-
DataTypes.TimestampType,
308-
hasIncompatibleType = usingParquetExecWithIncompatTypes)
313+
test("cast ShortType to TimestampType") {
314+
val compatibleTimezones = Seq(
315+
"UTC",
316+
"America/New_York",
317+
"America/Los_Angeles",
318+
"Europe/London",
319+
"Asia/Tokyo",
320+
"Australia/Sydney")
321+
compatibleTimezones.foreach { tz =>
322+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
323+
castTest(
324+
generateShorts(),
325+
DataTypes.TimestampType,
326+
hasIncompatibleType = usingParquetExecWithIncompatTypes)
327+
}
328+
}
309329
}
310330

311331
// CAST from integer
@@ -363,9 +383,19 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
363383
castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
364384
}
365385

366-
ignore("cast IntegerType to TimestampType") {
367-
// input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
368-
castTest(generateInts(), DataTypes.TimestampType)
386+
test("cast IntegerType to TimestampType") {
387+
val compatibleTimezones = Seq(
388+
"UTC",
389+
"America/New_York",
390+
"America/Los_Angeles",
391+
"Europe/London",
392+
"Asia/Tokyo",
393+
"Australia/Sydney")
394+
compatibleTimezones.foreach { tz =>
395+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
396+
castTest(generateInts(), DataTypes.TimestampType)
397+
}
398+
}
369399
}
370400

371401
// CAST from LongType
@@ -410,9 +440,26 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
410440
castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
411441
}
412442

413-
ignore("cast LongType to TimestampType") {
414-
// java.lang.ArithmeticException: long overflow
415-
castTest(generateLongs(), DataTypes.TimestampType)
443+
test("cast LongType to TimestampType") {
444+
// Use assertDataFrameEquals because extreme Long values (Long.MIN_VALUE, Long.MAX_VALUE)
445+
// overflow when converted to java.sql.Timestamp during collect(), but the cast itself works.
446+
val compatibleTimezones = Seq(
447+
"UTC",
448+
"America/New_York",
449+
"America/Los_Angeles",
450+
"Europe/London",
451+
"Asia/Tokyo",
452+
"Australia/Sydney")
453+
compatibleTimezones.foreach { tz =>
454+
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
455+
withTempPath { dir =>
456+
val input = generateLongs()
457+
val data = roundtripParquet(input, dir).coalesce(1)
458+
val df = data.withColumn("ts", col("a").cast(DataTypes.TimestampType))
459+
assertDataFrameEquals(df)
460+
}
461+
}
462+
}
416463
}
417464

418465
// CAST from FloatType
@@ -1042,13 +1089,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10421089

10431090
ignore("cast TimestampType to ShortType") {
10441091
// https://github.com/apache/datafusion-comet/issues/352
1045-
// input: 2023-12-31 10:00:00.0, expected: -21472, actual: null]
1092+
// input: 2023-12-31 10:00:00.0, expected: -21472, actual: null
10461093
castTest(generateTimestamps(), DataTypes.ShortType)
10471094
}
10481095

10491096
ignore("cast TimestampType to IntegerType") {
10501097
// https://github.com/apache/datafusion-comet/issues/352
1051-
// input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null]
1098+
// input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null
10521099
castTest(generateTimestamps(), DataTypes.IntegerType)
10531100
}
10541101

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,40 @@ abstract class CometTestBase
332332
}
333333
}
334334

335+
/**
 * Asserts that `df` produces the same rows with Comet enabled as it does with Comet disabled,
 * without collecting rows to the driver as external Java objects.
 *
 * The Spark baseline is materialized eagerly (via `localCheckpoint()`) while
 * `CometConf.COMET_ENABLED` is set to "false". A DataFrame is lazy: merely constructing it
 * inside the `withSQLConf` block would defer planning and execution until the later `count()`
 * calls, at which point both sides would run under the same configuration and the comparison
 * could never fail.
 *
 * Rows are compared with `exceptAll` in both directions so that differences in the number of
 * duplicate rows are detected as well.
 *
 * @param df
 *   the DataFrame under test; passed by name and re-evaluated for each plan lookup
 * @param checkNativeOperators
 *   when true, additionally runs `checkCometOperators` on the executed plan of `df`
 */
protected def assertDataFrameEquals(
    df: => DataFrame,
    checkNativeOperators: Boolean = true): Unit = {

  var sparkDf: DataFrame = null
  withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
    // localCheckpoint() is eager by default: it plans and runs the query right here, while
    // Comet is disabled, and returns a DataFrame backed by the already-computed rows.
    sparkDf = datasetOfRows(spark, df.logicalPlan).localCheckpoint()
  }
  val cometDf = datasetOfRows(spark, df.logicalPlan)

  // Check schemas match
  assert(
    sparkDf.schema == cometDf.schema,
    s"Schemas do not match.\nSpark: ${sparkDf.schema}\nComet: ${cometDf.schema}")

  // Compare using exceptAll() - this avoids collect() and toJavaTimestamp conversion,
  // and (unlike except()) keeps duplicate rows so multiplicity mismatches are caught.
  val sparkMinusComet = sparkDf.exceptAll(cometDf)
  val cometMinusSpark = cometDf.exceptAll(sparkDf)

  val diffCount1 = sparkMinusComet.count()
  val diffCount2 = cometMinusSpark.count()

  if (diffCount1 != 0 || diffCount2 != 0) {
    fail(
      "DataFrames are not equal.\n" +
        s"Rows in Spark but not in Comet: $diffCount1\n" +
        s"Rows in Comet but not in Spark: $diffCount2")
  }

  if (checkNativeOperators) {
    checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
  }
}
368+
335369
/**
336370
* A helper function for comparing Comet DataFrame with Spark result using absolute tolerance.
337371
*/

0 commit comments

Comments
 (0)