Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,11 @@ case class Cast(

override protected def withNewChildInternal(newChild: Expression): Cast = copy(child = newChild)

final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
final override def nodePatternsInternal(): Seq[TreePattern] = dataType match {
case _: TimestampNTZType | _: TimestampNTZNanosType |

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking (maintainability): CAST_TO_TIMESTAMP must stay set on a superset of the targets that Cast.isTimeToTimestampNTZ / isTimeToTimestampLTZ accept — otherwise ComputeCurrentTime silently stops reaching a rewrite it should perform (a CURRENT_DATE-derived TIME->TIMESTAMP left un-stabilized), and no current test would catch the drift since they're two independent match lists. Consider a short comment here cross-referencing those guards (and the rule).

Worth also recording why this keys on the target dataType and not child.dataType: node patterns are computed eagerly at construction before the child is resolved, so reading child.dataType throws (the OuterReference / CREATE FUNCTION ... RETURNS TABLE case removed in 51136ec). A future attempt to narrow the bit by also matching the source type would reintroduce that break — the target type is the only safe key.

TimestampType | _: TimestampLTZNanosType => Seq(CAST, CAST_TO_TIMESTAMP)
case _ => Seq(CAST)
}

override def contextIndependentFoldable: Boolean = {
child.contextIndependentFoldable && !Cast.needsTimeZone(child.dataType, dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,17 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
val currentDates = collection.mutable.HashMap.empty[ZoneId, Literal]
val localTimestamps = collection.mutable.HashMap.empty[ZoneId, Literal]

// The CAST bit is included so this rule can find TIME -> TIMESTAMP_NTZ and TIME ->
// TIMESTAMP_LTZ casts (which derive their date fields from CURRENT_DATE) and stabilize them
// below. CAST is a broad pattern, so this widens the rule's traversal to most plans; the
// precise `Cast.isTimeToTimestampNTZ` / `Cast.isTimeToTimestampLTZ` guards keep the rewrite
// scoped. We intentionally do not tag these casts with CURRENT_LIKE instead: inline-table
// validation treats CURRENT_LIKE as safe to defer, so tagging would let unrelated non-foldable
// NTZ/LTZ-target casts (e.g. CAST(rand() AS TIMESTAMP_NTZ)) bypass that validation (see
// SPARK-57618 and ResolveInlineTablesSuite).
// CAST_TO_TIMESTAMP is a dedicated tree-pattern bit set on Cast nodes whose target type is
// any timestamp type (NTZ or LTZ family). This lets the rule reach both TIME -> TIMESTAMP_NTZ
// and TIME -> TIMESTAMP_LTZ rewrites (which derive date fields from CURRENT_DATE) without the
// broad CAST pattern that previously widened traversal to nearly every plan. Node-level
// isTimeToTimestamp{NTZ,LTZ} guards keep rewrite semantics unchanged.
// We intentionally do NOT tag these casts with CURRENT_LIKE: inline-table validation treats
// CURRENT_LIKE as safe to defer, so tagging would let unrelated non-foldable timestamp-target
// casts (e.g. CAST(rand() AS TIMESTAMP_NTZ)) bypass validation (see SPARK-57618).
def transformCondition(treePatternbits: TreePatternBits): Boolean = {
treePatternbits.containsPattern(CURRENT_LIKE) || treePatternbits.containsPattern(CAST)
treePatternbits.containsPattern(CURRENT_LIKE) ||
treePatternbits.containsPattern(CAST_TO_TIMESTAMP)
}

plan.transformDownWithSubqueriesAndPruning(transformCondition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ object TreePattern extends Enumeration {
val BINARY_COMPARISON: Value = Value
val CASE_WHEN: Value = Value
val CAST: Value = Value
val CAST_TO_TIMESTAMP: Value = Value
val COALESCE: Value = Value
val COMMON_EXPR_REF: Value = Value
val CONCAT: Value = Value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Cast, CurrentDate,
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, IntegerType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampNTZType, TimestampType, TimeType}
import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampNTZType, TimestampType, TimeType}
import org.apache.spark.unsafe.types.UTF8String

class ComputeCurrentTimeSuite extends PlanTest {
Expand Down Expand Up @@ -342,4 +343,78 @@ class ComputeCurrentTimeSuite extends PlanTest {
}
literals
}

test("SPARK-57748: TIME->TIMESTAMP cast is rewritten even with no CURRENT_LIKE node") {
  // The only expression in the plan is a cast of a plain TIME literal, so nothing in it is a
  // CURRENT_LIKE node.
  val midnight = Literal(0L, TimeType(6))
  for (target <- Seq(TimestampNTZType, TimestampType)) {
    val projection = Project(Seq(Alias(Cast(midnight, target), "a")()), LocalRelation())
    val optimized = Optimize.execute(projection.analyze).asInstanceOf[Project]
    // Collect every TIME -> TIMESTAMP cast that is still present after optimization; any
    // survivor means the cast was not rewritten.
    val leftovers = optimized.expressions.flatMap { expr =>
      expr.collect {
        case c: Cast if Cast.isTimeToTimestampNTZ(c.child.dataType, c.dataType) ||
            Cast.isTimeToTimestampLTZ(c.child.dataType, c.dataType) => c
      }
    }
    assert(leftovers.isEmpty,
      s"TIME->$target cast should be rewritten with no CURRENT_LIKE present")
  }
}

test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is set for NTZ target types") {
  // Both NTZ targets (micro-precision and nanos) must carry the dedicated bit while still
  // keeping the pre-existing CAST bit.
  val source = Literal(0L, TimeType(6))
  Seq(TimestampNTZType, TimestampNTZNanosType(9)).foreach { target =>
    val ntzCast = Cast(source, target)
    assert(ntzCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
    assert(ntzCast.containsPattern(TreePattern.CAST))
  }
}

test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is NOT set for non-timestamp targets") {
  // Casts whose target is not a timestamp type (string and integer here) keep the plain CAST
  // bit but must not be tagged with CAST_TO_TIMESTAMP.
  val nonTimestampCasts = Seq(
    Cast(Literal(0L, TimeType(6)), StringType),
    Cast(Literal("10"), IntegerType))
  nonTimestampCasts.foreach { cast =>
    assert(!cast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
    assert(cast.containsPattern(TreePattern.CAST))
  }
}

test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is set for LTZ targets") {
  // LTZ targets (TimestampType, i.e. LTZ micro, and the LTZ nanos type) are tagged as well,
  // because ComputeCurrentTime also rewrites TIME -> LTZ casts.
  val timeLiteral = Literal(0L, TimeType(6))
  for (target <- Seq(TimestampType, TimestampLTZNanosType(9))) {
    val ltzCast = Cast(timeLiteral, target)
    assert(ltzCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
    assert(ltzCast.containsPattern(TreePattern.CAST))
  }
}

test("SPARK-57748: CAST_TO_TIMESTAMP is keyed on target type, not source type") {
  import org.apache.spark.sql.catalyst.expressions.Rand

  // The target type alone decides the bit, whatever the child is: a string literal, an
  // integer literal, or a non-literal expression such as rand().
  val children = Seq(Literal("2024-01-01"), Literal(42), Rand(Literal(0L)))
  children.foreach { child =>
    assert(Cast(child, TimestampNTZType).containsPattern(TreePattern.CAST_TO_TIMESTAMP))
  }
}

test("SPARK-57748: plan with non-timestamp cast only does not contain CAST_TO_TIMESTAMP") {
  // The check is made on the Project node itself: a plan whose only cast targets IntegerType
  // exposes CAST but not CAST_TO_TIMESTAMP.
  val castToInt = Cast(Literal(0L, TimeType(6)), IntegerType)
  val projection = Project(Seq(Alias(castToInt, "a")()), LocalRelation())
  assert(!projection.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
  assert(projection.containsPattern(TreePattern.CAST))
}
}