Skip to content

Commit 68b31af

Browse files
committed
[SPARK-57784][SQL] Support the TIME data type in cost-based optimizer statistics estimation
1 parent 08528e9 commit 68b31af

7 files changed

Lines changed: 68 additions & 8 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,7 @@ object CatalogColumnStat extends Logging {
10471047
case TimestampType => getTimestampFormatter(isParsing = true).parse(s)
10481048
case TimestampNTZType =>
10491049
getTimestampFormatter(isParsing = true, forTimestampNTZ = true).parse(s)
1050+
case _: TimeType => TimeFormatter(isParsing = true).parse(s)
10501051
case ByteType => s.toByte
10511052
case ShortType => s.toShort
10521053
case IntegerType => s.toInt
@@ -1073,6 +1074,7 @@ object CatalogColumnStat extends Logging {
10731074
case TimestampNTZType =>
10741075
getTimestampFormatter(isParsing = false, forTimestampNTZ = true)
10751076
.format(v.asInstanceOf[Long])
1077+
case _: TimeType => TimeFormatter(isParsing = false).format(v.asInstanceOf[Long])
10761078
case BooleanType | _: IntegralType | FloatType | DoubleType => v
10771079
case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
10781080
// This version of Spark does not use min/max for binary/string types so we ignore it.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ object EstimationUtils {
135135
*/
136136
def toDouble(value: Any, dataType: DataType): Double = {
137137
dataType match {
138-
case _: NumericType | DateType | TimestampType => value.toString.toDouble
138+
case _: NumericType | DateType | TimestampType | _: TimeType => value.toString.toDouble
139139
case BooleanType => if (value.asInstanceOf[Boolean]) 1 else 0
140140
}
141141
}
@@ -144,7 +144,7 @@ object EstimationUtils {
144144
dataType match {
145145
case BooleanType => double.toInt == 1
146146
case DateType => double.toInt
147-
case TimestampType => double.toLong
147+
case TimestampType | _: TimeType => double.toLong
148148
case ByteType => double.toByte
149149
case ShortType => double.toShort
150150
case IntegerType => double.toInt

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
291291
}
292292

293293
attr.dataType match {
294-
case _: NumericType | DateType | TimestampType | BooleanType =>
294+
case _: NumericType | DateType | TimestampType | BooleanType | _: TimeType =>
295295
evaluateBinaryForNumeric(op, attr, literal, update)
296296
case StringType | BinaryType =>
297297
// TODO: It is difficult to support other binary comparisons for String/Binary
@@ -413,7 +413,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
413413

414414
// use [min, max] to filter the original hSet
415415
dataType match {
416-
case _: NumericType | BooleanType | DateType | TimestampType =>
416+
case _: NumericType | BooleanType | DateType | TimestampType | _: TimeType =>
417417
if (ndv.toDouble == 0) {
418418
return Some(0.0)
419419
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ object UnionEstimation {
3838
private def isTypeSupported(dt: DataType): Boolean = dt match {
3939
case ByteType | IntegerType | ShortType | FloatType | LongType |
4040
DoubleType | DateType | _: DecimalType | TimestampType | TimestampNTZType |
41-
_: AnsiIntervalType => true
41+
_: AnsiIntervalType | _: TimeType => true
4242
case _ => false
4343
}
4444

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
5454
min = Some(dMin), max = Some(dMax),
5555
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4))
5656

57+
// column ctime has 10 values from 08:00:00 through 17:00:00 (nanos of day).
58+
// 08:00 = 28800000000000L nanos, 17:00 = 61200000000000L nanos
59+
val tMin = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(8, 0, 0))
60+
val tMax = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(17, 0, 0))
61+
val attrTime = AttributeReference("ctime", TimeType())()
62+
val colStatTime = ColumnStat(distinctCount = Some(10),
63+
min = Some(tMin), max = Some(tMax),
64+
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
65+
5766
// column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
5867
val decMin = Decimal("0.200000000000000000")
5968
val decMax = Decimal("0.800000000000000000")
@@ -118,6 +127,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
118127
attrInt -> colStatInt,
119128
attrBool -> colStatBool,
120129
attrDate -> colStatDate,
130+
attrTime -> colStatTime,
121131
attrDecimal -> colStatDecimal,
122132
attrDouble -> colStatDouble,
123133
attrString -> colStatString,
@@ -523,6 +533,44 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
523533
expectedRowCount = 3)
524534
}
525535

536+
test("ctime = cast('10:00:00' AS TIME)") {
537+
val t10 = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(10, 0, 0))
538+
validateEstimatedStats(
539+
Filter(EqualTo(attrTime, Literal(t10, TimeType())),
540+
childStatsTestPlan(Seq(attrTime), 10L)),
541+
Seq(attrTime -> ColumnStat(distinctCount = Some(1),
542+
min = Some(t10), max = Some(t10),
543+
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
544+
expectedRowCount = 1)
545+
}
546+
547+
test("ctime < cast('12:00:00' AS TIME)") {
548+
// 12:00 is 43200000000000L nanos. Range is [08:00, 17:00] = 10 distinct values.
549+
// Fraction: (12:00 - 08:00) / (17:00 - 08:00) = 4/9 hours => ~4.44 => rounded to 5
550+
val t12 = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(12, 0, 0))
551+
validateEstimatedStats(
552+
Filter(LessThan(attrTime, Literal(t12, TimeType())),
553+
childStatsTestPlan(Seq(attrTime), 10L)),
554+
Seq(attrTime -> ColumnStat(distinctCount = Some(5),
555+
min = Some(tMin), max = Some(t12),
556+
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
557+
expectedRowCount = 5)
558+
}
559+
560+
test("""ctime IN ( cast('09:00:00' AS TIME),
561+
cast('10:00:00' AS TIME), cast('11:00:00' AS TIME) )""") {
562+
val t09 = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(9, 0, 0))
563+
val t10 = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(10, 0, 0))
564+
val t11 = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(11, 0, 0))
565+
validateEstimatedStats(
566+
Filter(In(attrTime, Seq(Literal(t09, TimeType()), Literal(t10, TimeType()),
567+
Literal(t11, TimeType()))), childStatsTestPlan(Seq(attrTime), 10L)),
568+
Seq(attrTime -> ColumnStat(distinctCount = Some(3),
569+
min = Some(t09), max = Some(t11),
570+
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))),
571+
expectedRowCount = 3)
572+
}
573+
526574
test("cdecimal = 0.400000000000000000") {
527575
val dec_0_40 = Decimal("0.400000000000000000")
528576
validateEstimatedStats(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
472472
val dec = Decimal("1.000000000000000000")
473473
val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
474474
val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
475+
val time = DateTimeUtils.localTimeToNanos(java.time.LocalTime.of(10, 30, 0))
475476
mutable.LinkedHashMap[Attribute, ColumnStat](
476477
AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = Some(1),
477478
min = Some(false), max = Some(false),
@@ -506,6 +507,9 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
506507
nullCount = Some(0), avgLen = Some(4), maxLen = Some(4)),
507508
AttributeReference("ctimestamp", TimestampType)() -> ColumnStat(distinctCount = Some(1),
508509
min = Some(timestamp), max = Some(timestamp),
510+
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8)),
511+
AttributeReference("ctime", TimeType())() -> ColumnStat(distinctCount = Some(1),
512+
min = Some(time), max = Some(time),
509513
nullCount = Some(0), avgLen = Some(8), maxLen = Some(8))
510514
)
511515
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
6060
val attrTimestampNTZ = AttributeReference("ctimestamp_ntz", TimestampNTZType)()
6161
val attrYMInterval = AttributeReference("cyminterval", YearMonthIntervalType())()
6262
val attrDTInterval = AttributeReference("cdtinterval", DayTimeIntervalType())()
63+
val attrTime = AttributeReference("ctime", TimeType())()
6364

6465
val s1 = 1.toShort
6566
val s2 = 4.toShort
@@ -90,7 +91,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
9091
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)),
9192
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(4L)),
9293
attrYMInterval -> ColumnStat(min = Some(2), max = Some(5)),
93-
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L))))
94+
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L)),
95+
attrTime -> ColumnStat(min = Some(1000L), max = Some(4000L))))
9496

9597
val s3 = 2.toShort
9698
val s4 = 6.toShort
@@ -133,7 +135,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
133135
max = Some(8)),
134136
AttributeReference("cdttimestamp1", DayTimeIntervalType())() -> ColumnStat(
135137
min = Some(4L),
136-
max = Some(8L))))
138+
max = Some(8L)),
139+
AttributeReference("ctime1", TimeType())() -> ColumnStat(
140+
min = Some(3000L),
141+
max = Some(6000L))))
137142

138143
val child1 = StatsTestPlan(
139144
outputList = columnInfo.keys.toSeq.sortWith(_.exprId.id < _.exprId.id),
@@ -167,7 +172,8 @@ class UnionEstimationSuite extends StatsEstimationTestBase {
167172
attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)),
168173
attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(6L)),
169174
attrYMInterval -> ColumnStat(min = Some(2), max = Some(8)),
170-
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L)))))
175+
attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L)),
176+
attrTime -> ColumnStat(min = Some(1000L), max = Some(6000L)))))
171177
assert(union.stats === expectedStats)
172178
}
173179

0 commit comments

Comments
 (0)