Skip to content

Commit 5aa49b0

Browse files
author
B Vadlamani
committed
non_numeric_to_timestamp
1 parent 5651fdc commit 5aa49b0

4 files changed

Lines changed: 49 additions & 23 deletions

File tree

native/spark-expr/benches/cast_non_int_numeric_timestamp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ fn create_decimal128_batch() -> RecordBatch {
117117
if i % 10 == 0 {
118118
b.append_null();
119119
} else {
120-
b.append_value(rand::random::<i64>() as i128);
120+
b.append_value(i as i128 * 1_000_000);
121121
}
122122
}
123123
let array = b.finish().with_precision_and_scale(18, 6).unwrap();

spark/src/main/scala/org/apache/comet/expressions/CometCast.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
7070
case _ =>
7171
if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, cometEvalMode)) {
7272
exprToProtoInternal(Literal.create(null, cast.dataType), inputs, binding)
73-
} else if (isAlwaysCastToUTC(cast.child.dataType, cast.dataType, cometEvalMode)) {
74-
exprToProtoInternal(Literal.create(0L, cast.dataType), inputs, binding)
75-
} else {
73+
}
74+
else {
7675
val childExpr = exprToProtoInternal(cast.child, inputs, binding)
7776
if (childExpr.isDefined) {
7877
castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, cometEvalMode)
@@ -84,17 +83,6 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
8483
}
8584
}
8685

87-
private def isAlwaysCastToUTC(
88-
fromType: DataType,
89-
toType: DataType,
90-
evalMode: CometEvalMode.Value): Boolean = {
91-
(fromType, toType) match {
92-
case (DataTypes.BooleanType, DataTypes.TimestampType) if evalMode == CometEvalMode.ANSI =>
93-
true
94-
case _ => false
95-
}
96-
}
97-
9886
// Some casts like date -> int/byte / long are always null. Terminate early in planning
9987
private def isAlwaysCastToNull(
10088
fromType: DataType,

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
167167
castTest(generateBools(), DataTypes.StringType)
168168
}
169169

170-
ignore("cast BooleanType to TimestampType") {
170+
test("cast BooleanType to TimestampType") {
171171
// Previously ignored because of: Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, Some("UTC")) not supported
172172
castTest(generateBools(), DataTypes.TimestampType)
173173
}
@@ -507,7 +507,11 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
507507
test("cast FloatType to TimestampType") {
508508
compatibleTimezones.foreach { tz =>
509509
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
510-
castTest(generateFloats(), DataTypes.TimestampType)
510+
withTable("t1") {
511+
generateFloats().write.saveAsTable("t1")
512+
val df = spark.sql("select a, cast(a as timestamp) from t1")
513+
assertDataFrameEquals(df)
514+
}
511515
}
512516
}
513517
}
@@ -564,13 +568,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
564568
}
565569

566570
test("cast DoubleType to TimestampType") {
567-
// Cast back to double avoids java.sql.Timestamp overflow during collect() for extreme values
568571
compatibleTimezones.foreach { tz =>
569572
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
570573
withTable("t1") {
571-
generateLongs().write.saveAsTable("t1")
572-
val df = spark.sql("select a, cast(cast(a as timestamp) as double) from t1")
573-
checkSparkAnswerAndOperator(df)
574+
generateDoubles().write.saveAsTable("t1")
575+
val df = spark.sql("select a, cast(a as timestamp) from t1")
576+
assertDataFrameEquals(df)
574577
}
575578
}
576579
}
@@ -638,8 +641,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
638641
castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
639642
}
640643

641-
ignore("cast DecimalType(10,2) to TimestampType") {
642-
// input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, actual: 1969-12-31 15:59:59.876544
644+
test("cast DecimalType(10,2) to TimestampType") {
643645
castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
644646
}
645647

spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,4 +1276,40 @@ abstract class CometTestBase
12761276
!usingLegacyNativeCometScan(conf) &&
12771277
CometConf.COMET_PARQUET_UNSIGNED_SMALL_INT_CHECK.get(conf)
12781278
}
1279+
1280+
/**
 * Asserts that `df` yields the same schema and the same rows (including duplicates) when it is
 * run with Comet disabled (plain Spark) and with the current session configuration.
 *
 * Rows are compared with `exceptAll` instead of `collect()`, because collecting extremely high
 * timestamp values errors out.
 *
 * @param df
 *   the query under test; it is passed by name and re-evaluated for each run
 * @param assertCometNative
 *   when true, also runs `checkCometOperators` on the executed plan of `df`
 */
protected def assertDataFrameEquals(
    df: => DataFrame,
    assertCometNative: Boolean = true): Unit = {

  // `withSQLConf` takes a side-effecting block, so the Spark baseline is handed out via a var.
  var sparkDf: DataFrame = null
  withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
    // DataFrames are lazy: without an action inside this block, the baseline would only be
    // planned and executed later, after the conf is restored, i.e. with Comet enabled again.
    // An eager local checkpoint materializes the rows here, while Comet is disabled, and does
    // not go through collect().
    sparkDf = datasetOfRows(spark, df.logicalPlan).localCheckpoint()
  }
  val cometDf = datasetOfRows(spark, df.logicalPlan)

  // Compare schemas
  assert(
    sparkDf.schema == cometDf.schema,
    s"Schema mismatch:\nSpark schema: ${sparkDf.schema}\nComet schema: ${cometDf.schema}")

  // exceptAll keeps duplicates, so a row that occurs a different number of times on each side
  // is reported as a difference; plain except would only compare the distinct rows.
  val onlyInSpark = sparkDf.exceptAll(cometDf).count()
  val onlyInComet = cometDf.exceptAll(sparkDf).count()

  assert(
    onlyInSpark == 0 && onlyInComet == 0,
    s"DataFrames do not match: $onlyInSpark row(s) only in the Spark result, " +
      s"$onlyInComet row(s) only in the Comet result")

  if (assertCometNative) {
    checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
  }
}
12791315
}

0 commit comments

Comments
 (0)