diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 96ebe62b76ad8..84d5813bc77b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -689,7 +689,15 @@ case class Cast(
   override protected def withNewChildInternal(newChild: Expression): Cast =
     copy(child = newChild)
 
-  final override def nodePatternsInternal(): Seq[TreePattern] = Seq(CAST)
+  // All casts carry CAST. Casts whose target `dataType` is a timestamp type (NTZ or LTZ family,
+  // including the nanos variants) additionally carry CAST_TO_TIMESTAMP so that rules such as
+  // ComputeCurrentTime can prune on that narrower bit. Only the target type is inspected here.
+  final override def nodePatternsInternal(): Seq[TreePattern] = dataType match {
+    case _: TimestampType | _: TimestampLTZNanosType |
+         _: TimestampNTZType | _: TimestampNTZNanosType =>
+      Seq(CAST, CAST_TO_TIMESTAMP)
+    case _ => Seq(CAST)
+  }
 
   override def contextIndependentFoldable: Boolean = {
     child.contextIndependentFoldable && !Cast.needsTimeZone(child.dataType, dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 437b67d0855d3..c8c00a3fa13ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -120,16 +120,17 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     val currentDates = collection.mutable.HashMap.empty[ZoneId, Literal]
     val localTimestamps = collection.mutable.HashMap.empty[ZoneId, Literal]
 
-    // The CAST bit is included so this rule can find TIME -> TIMESTAMP_NTZ and TIME ->
-    // TIMESTAMP_LTZ casts (which derive their date fields from CURRENT_DATE) and stabilize them
-    // below. CAST is a broad pattern, so this widens the rule's traversal to most plans; the
-    // precise `Cast.isTimeToTimestampNTZ` / `Cast.isTimeToTimestampLTZ` guards keep the rewrite
-    // scoped. We intentionally do not tag these casts with CURRENT_LIKE instead: inline-table
-    // validation treats CURRENT_LIKE as safe to defer, so tagging would let unrelated non-foldable
-    // NTZ/LTZ-target casts (e.g. CAST(rand() AS TIMESTAMP_NTZ)) bypass that validation (see
-    // SPARK-57618 and ResolveInlineTablesSuite).
+    // CAST_TO_TIMESTAMP is a dedicated tree-pattern bit set on Cast nodes whose target type is
+    // any timestamp type (NTZ or LTZ family). This lets the rule reach both TIME -> TIMESTAMP_NTZ
+    // and TIME -> TIMESTAMP_LTZ rewrites (which derive date fields from CURRENT_DATE) without the
+    // broad CAST pattern that previously widened traversal to nearly every plan. Node-level
+    // isTimeToTimestamp{NTZ,LTZ} guards keep rewrite semantics unchanged.
+    // We intentionally do NOT tag these casts with CURRENT_LIKE: inline-table validation treats
+    // CURRENT_LIKE as safe to defer, so tagging would let unrelated non-foldable timestamp-target
+    // casts (e.g. CAST(rand() AS TIMESTAMP_NTZ)) bypass validation (see SPARK-57618).
     def transformCondition(treePatternbits: TreePatternBits): Boolean = {
-      treePatternbits.containsPattern(CURRENT_LIKE) || treePatternbits.containsPattern(CAST)
+      treePatternbits.containsPattern(CURRENT_LIKE) ||
+        treePatternbits.containsPattern(CAST_TO_TIMESTAMP)
     }
 
     plan.transformDownWithSubqueriesAndPruning(transformCondition) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 173df28e2b248..94b4666a88a8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -41,6 +41,7 @@ object TreePattern extends Enumeration {
   val BINARY_COMPARISON: Value = Value
   val CASE_WHEN: Value = Value
   val CAST: Value = Value
+  val CAST_TO_TIMESTAMP: Value = Value
   val COALESCE: Value = Value
   val COMMON_EXPR_REF: Value = Value
   val CONCAT: Value = Value
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
index a94bb6f6c5c1b..be24f9c9f01f9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala
@@ -28,9 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.{Add, Alias, Cast, CurrentDate,
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.trees.TreePattern
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, IntegerType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampNTZType, TimestampType, TimeType}
+import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampLTZNanosType, TimestampNTZNanosType, TimestampNTZType, TimestampType, TimeType}
 import org.apache.spark.unsafe.types.UTF8String
 
 class ComputeCurrentTimeSuite extends PlanTest {
@@ -342,4 +343,78 @@ class ComputeCurrentTimeSuite extends PlanTest {
     }
     literals
   }
+
+  test("SPARK-57748: TIME->TIMESTAMP cast is rewritten even with no CURRENT_LIKE node") {
+    val timeLit = Literal(0L, TimeType(6))
+    Seq(TimestampNTZType, TimestampType).foreach { target =>
+      val in = Project(Seq(Alias(Cast(timeLit, target), "a")()), LocalRelation())
+      val plan = Optimize.execute(in.analyze)
+      val remaining = plan.expressions.flatMap(_.collect {
+        case c: Cast if Cast.isTimeToTimestampNTZ(c.child.dataType, c.dataType)
+          || Cast.isTimeToTimestampLTZ(c.child.dataType, c.dataType) => c
+      })
+      assert(remaining.isEmpty,
+        s"TIME->$target cast should be rewritten with no CURRENT_LIKE present")
+    }
+  }
+
+  test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is set for NTZ target types") {
+    // Cast with TimestampNTZType target should contain CAST_TO_TIMESTAMP
+    val ntzCast = Cast(Literal(0L, TimeType(6)), TimestampNTZType)
+    assert(ntzCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(ntzCast.containsPattern(TreePattern.CAST)) // existing CAST tag preserved
+
+    // Cast with TimestampNTZNanosType target should also contain CAST_TO_TIMESTAMP
+    val ntzNanosCast = Cast(Literal(0L, TimeType(6)), TimestampNTZNanosType(9))
+    assert(ntzNanosCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(ntzNanosCast.containsPattern(TreePattern.CAST))
+  }
+
+  test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is NOT set for non-timestamp targets") {
+    // Cast to StringType should NOT contain CAST_TO_TIMESTAMP
+    val stringCast = Cast(Literal(0L, TimeType(6)), StringType)
+    assert(!stringCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(stringCast.containsPattern(TreePattern.CAST))
+
+    // Cast to IntegerType should NOT contain CAST_TO_TIMESTAMP
+    val intCast = Cast(Literal("10"), IntegerType)
+    assert(!intCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(intCast.containsPattern(TreePattern.CAST))
+  }
+
+  test("SPARK-57748: CAST_TO_TIMESTAMP tree pattern is set for LTZ targets") {
+    // Cast to TimestampType (LTZ micro) should contain CAST_TO_TIMESTAMP because
+    // ComputeCurrentTime rewrites TIME->LTZ casts via the same predicate.
+    val ltzCast = Cast(Literal(0L, TimeType(6)), TimestampType)
+    assert(ltzCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(ltzCast.containsPattern(TreePattern.CAST))
+
+    // Cast to TimestampLTZNanosType should also contain CAST_TO_TIMESTAMP
+    val ltzNanosCast = Cast(Literal(0L, TimeType(6)), TimestampLTZNanosType(9))
+    assert(ltzNanosCast.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(ltzNanosCast.containsPattern(TreePattern.CAST))
+  }
+
+  test("SPARK-57748: CAST_TO_TIMESTAMP is keyed on target type, not source type") {
+    // Source type does not matter - only the target determines the pattern bit
+    val fromString = Cast(Literal("2024-01-01"), TimestampNTZType)
+    assert(fromString.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+
+    val fromInt = Cast(Literal(42), TimestampNTZType)
+    assert(fromInt.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+
+    // Even with an expression child (rand()), the target type determines the bit
+    import org.apache.spark.sql.catalyst.expressions.Rand
+    val fromRand = Cast(Rand(Literal(0L)), TimestampNTZType)
+    assert(fromRand.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+  }
+
+  test("SPARK-57748: plan with non-timestamp cast only does not contain CAST_TO_TIMESTAMP") {
+    val timeLit = Literal(0L, TimeType(6))
+    val plan = Project(Seq(
+      Alias(Cast(timeLit, IntegerType), "a")()),
+      LocalRelation())
+    assert(!plan.containsPattern(TreePattern.CAST_TO_TIMESTAMP))
+    assert(plan.containsPattern(TreePattern.CAST))
+  }
 }