Skip to content

Commit 88e0c83

Browse files
committed
non_numeric_to_timestamp
1 parent 5651fdc commit 88e0c83

5 files changed

Lines changed: 58 additions & 30 deletions

File tree

native/spark-expr/benches/cast_non_int_numeric_timestamp.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,23 @@ fn criterion_benchmark(c: &mut Criterion) {
5050
// Boolean -> Timestamp
5151
let batch_bool = create_boolean_batch();
5252
let expr_bool = Arc::new(Column::new("a", 0));
53-
let cast_bool_to_ts = Cast::new(expr_bool, timestamp_type.clone(), spark_cast_options.clone());
53+
let cast_bool_to_ts = Cast::new(
54+
expr_bool,
55+
timestamp_type.clone(),
56+
spark_cast_options.clone(),
57+
);
5458
group.bench_function("cast_bool_to_timestamp", |b| {
5559
b.iter(|| cast_bool_to_ts.evaluate(&batch_bool).unwrap());
5660
});
5761

5862
// Decimal128 -> Timestamp
5963
let batch_decimal = create_decimal128_batch();
6064
let expr_decimal = Arc::new(Column::new("a", 0));
61-
let cast_decimal_to_ts =
62-
Cast::new(expr_decimal, timestamp_type.clone(), spark_cast_options.clone());
65+
let cast_decimal_to_ts = Cast::new(
66+
expr_decimal,
67+
timestamp_type.clone(),
68+
spark_cast_options.clone(),
69+
);
6370
group.bench_function("cast_decimal_to_timestamp", |b| {
6471
b.iter(|| cast_decimal_to_ts.evaluate(&batch_decimal).unwrap());
6572
});
@@ -117,7 +124,7 @@ fn create_decimal128_batch() -> RecordBatch {
117124
if i % 10 == 0 {
118125
b.append_null();
119126
} else {
120-
b.append_value(rand::random::<i64>() as i128);
127+
b.append_value(i as i128 * 1_000_000);
121128
}
122129
}
123130
let array = b.finish().with_precision_and_scale(18, 6).unwrap();

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,11 +3644,8 @@ mod tests {
36443644
];
36453645

36463646
for tz in &timezones {
3647-
let bool_array: ArrayRef = Arc::new(BooleanArray::from(vec![
3648-
Some(true),
3649-
Some(false),
3650-
None,
3651-
]));
3647+
let bool_array: ArrayRef =
3648+
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));
36523649

36533650
let result = cast_boolean_to_timestamp(&bool_array, tz).unwrap();
36543651
let ts_array = result.as_primitive::<TimestampMicrosecondType>();

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
7070
case _ =>
7171
if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, cometEvalMode)) {
7272
exprToProtoInternal(Literal.create(null, cast.dataType), inputs, binding)
73-
} else if (isAlwaysCastToUTC(cast.child.dataType, cast.dataType, cometEvalMode)) {
74-
exprToProtoInternal(Literal.create(0L, cast.dataType), inputs, binding)
7573
} else {
7674
val childExpr = exprToProtoInternal(cast.child, inputs, binding)
7775
if (childExpr.isDefined) {
@@ -84,17 +82,6 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
8482
}
8583
}
8684

87-
private def isAlwaysCastToUTC(
88-
fromType: DataType,
89-
toType: DataType,
90-
evalMode: CometEvalMode.Value): Boolean = {
91-
(fromType, toType) match {
92-
case (DataTypes.BooleanType, DataTypes.TimestampType) if evalMode == CometEvalMode.ANSI =>
93-
true
94-
case _ => false
95-
}
96-
}
97-
9885
// Some casts like date -> int/byte / long are always null. Terminate early in planning
9986
private def isAlwaysCastToNull(
10087
fromType: DataType,

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
167167
castTest(generateBools(), DataTypes.StringType)
168168
}
169169

170-
ignore("cast BooleanType to TimestampType") {
170+
test("cast BooleanType to TimestampType") {
171171
// Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported
172172
castTest(generateBools(), DataTypes.TimestampType)
173173
}
@@ -507,7 +507,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
507507
test("cast FloatType to TimestampType") {
508508
compatibleTimezones.foreach { tz =>
509509
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
510-
castTest(generateFloats(), DataTypes.TimestampType)
510+
withTable("t1") {
511+
generateFloats().write.saveAsTable("t1")
512+
val df = spark.sql("select a, cast(a as timestamp) from t1")
513+
assertDataFrameEquals(df)
514+
}
511515
}
512516
}
513517
}
@@ -564,13 +568,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
564568
}
565569

566570
test("cast DoubleType to TimestampType") {
567-
// Cast back to double avoids java.sql.Timestamp overflow during collect() for extreme values
568571
compatibleTimezones.foreach { tz =>
569572
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
570573
withTable("t1") {
571-
generateLongs().write.saveAsTable("t1")
572-
val df = spark.sql("select a, cast(cast(a as timestamp) as double) from t1")
573-
checkSparkAnswerAndOperator(df)
574+
generateDoubles().write.saveAsTable("t1")
575+
val df = spark.sql("select a, cast(a as timestamp) from t1")
576+
assertDataFrameEquals(df)
574577
}
575578
}
576579
}
@@ -638,8 +641,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
638641
castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
639642
}
640643

641-
ignore("cast DecimalType(10,2) to TimestampType") {
642-
// input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 15:59:59.876544
644+
test("cast DecimalType(10,2) to TimestampType") {
643645
castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
644646
}
645647

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,4 +1276,39 @@ abstract class CometTestBase
12761276
!usingLegacyNativeCometScan(conf) &&
12771277
CometConf.COMET_PARQUET_UNSIGNED_SMALL_INT_CHECK.get(conf)
12781278
}
1279+
1280+
/**
1281+
* Uses except (difference) to find differences without using collect(). Checks cometDF and
1282+
* sparkDF including schemas
1283+
*/
1284+
protected def assertDataFrameEquals(
1285+
df: => DataFrame,
1286+
assertCometNative: Boolean = true): Unit = {
1287+
1288+
var sparkDf: DataFrame = null
1289+
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
1290+
sparkDf = datasetOfRows(spark, df.logicalPlan)
1291+
}
1292+
val cometDf = datasetOfRows(spark, df.logicalPlan)
1293+
1294+
// Compare schemas
1295+
assert(
1296+
sparkDf.schema == cometDf.schema,
1297+
s"Schema mismatch:\nCorrect Answer: ${sparkDf.schema}\nSpark Answer: ${cometDf.schema}")
1298+
1299+
// Use except (difference) to compare DataFrames without collect(), which errors on extremely high Timestamp values
1300+
val sparkMinusComet = sparkDf.except(cometDf)
1301+
val cometMinusSpark = cometDf.except(sparkDf)
1302+
1303+
val diffCount1 = sparkMinusComet.count()
1304+
val diffCount2 = cometMinusSpark.count()
1305+
1306+
if (diffCount1 > 0 || diffCount2 > 0) {
1307+
fail("DataFrames count doesnt match.\n")
1308+
}
1309+
1310+
if (assertCometNative) {
1311+
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
1312+
}
1313+
}
12791314
}

0 commit comments

Comments
 (0)