Skip to content

Commit fc9010c

Browse files
committed
feat: JVM UDF fallback for date_format
CometDateFormat now picks between native to_char and a new DateFormatUDF that wraps Spark's DateFormatClass. Native is used only when the format is a literal in the strftime-mappable whitelist and, in addition, either the timezone is UTC or spark.comet.expression.DateFormatClass.allowIncompatible is set. All other cases (non-UTC timezone without that flag, non-literal format, format outside the whitelist) now run inside Comet via the JVM UDF instead of falling back to Spark. Unlike the regexp engine config, there's no new user-facing knob: the JVM UDF is a transparent correctness fallback that delegates to Spark's own implementation, so behavior matches Spark by construction.
1 parent b993c0f commit fc9010c

3 files changed

Lines changed: 195 additions & 81 deletions

File tree

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
2727
import org.apache.spark.unsafe.types.UTF8String
2828

29+
import org.apache.comet.CometConf
2930
import org.apache.comet.CometSparkSessionExtensions.withInfo
3031
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3132
import org.apache.comet.serde.CometGetDateField.CometGetDateField
3233
import org.apache.comet.serde.ExprOuterClass.Expr
3334
import org.apache.comet.serde.QueryPlanSerde._
35+
import org.apache.comet.udf.DateFormatUDF
3436

3537
private object CometGetDateField extends Enumeration {
3638
type CometGetDateField = Value
@@ -572,17 +574,21 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
572574
}
573575

574576
/**
575-
* Converts Spark DateFormatClass expression to DataFusion's to_char function.
577+
* Converts Spark DateFormatClass expression to DataFusion's to_char function when the format and
578+
* timezone are mappable; otherwise emits a JvmScalarUdf that delegates to Spark's own
579+
* `DateFormatClass` so that any format / timezone combination remains supported.
576580
*
577-
* Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This
578-
* implementation supports a whitelist of common format strings that can be reliably mapped
579-
* between the two systems.
581+
* Routing:
582+
* - format is a literal in `supportedFormats` AND timezone is UTC -> native to_char
583+
* - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression
584+
* allowIncompatible flag set -> native to_char (results may differ from Spark)
585+
* - all other cases -> JVM UDF (`org.apache.comet.udf.DateFormatUDF`)
580586
*/
581587
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
582588

583589
/**
584590
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
585-
* are supported.
591+
* are supported by the native path.
586592
*/
587593
val supportedFormats: Map[String, String] = Map(
588594
// Full date formats
@@ -616,67 +622,70 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
616622
// ISO formats
617623
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")
618624

619-
override def getIncompatibleReasons(): Seq[String] = Seq(
620-
"Non-UTC timezones may produce different results than Spark")
621-
622-
override def getUnsupportedReasons(): Seq[String] = Seq(
623-
"Only the following formats are supported:" +
624-
supportedFormats.keys.toSeq.sorted
625-
.map(k => s"`$k`")
626-
.mkString("\n - ", "\n - ", ""))
625+
// The JVM UDF covers every case that the native path cannot, so the expression is always
626+
// emittable. Compatibility decisions happen inside `convert`.
627+
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()
627628

628-
override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
629-
// Check timezone - only UTC is fully compatible
629+
override def convert(
630+
expr: DateFormatClass,
631+
inputs: Seq[Attribute],
632+
binding: Boolean): Option[ExprOuterClass.Expr] = {
630633
val timezone = expr.timeZoneId.getOrElse("UTC")
631634
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
632635

633-
expr.right match {
634-
case Literal(fmt: UTF8String, _) =>
635-
val format = fmt.toString
636-
if (supportedFormats.contains(format)) {
637-
if (isUtc) {
638-
Compatible()
639-
} else {
640-
Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results"))
641-
}
642-
} else {
643-
Unsupported(
644-
Some(
645-
s"Format '$format' is not supported. Supported formats: " +
646-
supportedFormats.keys.mkString(", ")))
647-
}
648-
case _ =>
649-
Unsupported(Some("Only literal format strings are supported"))
636+
val nativeFormat: Option[String] = expr.right match {
637+
case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString)
638+
case _ => None
639+
}
640+
641+
val canUseNative = nativeFormat.isDefined && {
642+
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
643+
}
644+
645+
if (canUseNative) {
646+
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
647+
val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding)
648+
val optExpr = scalarFunctionExprToProtoWithReturnType(
649+
"to_char",
650+
StringType,
651+
false,
652+
childExpr,
653+
formatExpr)
654+
optExprWithInfo(optExpr, expr, expr.left, expr.right)
655+
} else {
656+
convertViaJvmUdf(expr, timezone, inputs, binding)
650657
}
651658
}
652659

653-
override def convert(
660+
private def convertViaJvmUdf(
654661
expr: DateFormatClass,
662+
timezone: String,
655663
inputs: Seq[Attribute],
656664
binding: Boolean): Option[ExprOuterClass.Expr] = {
657-
// Get the format string - must be a literal for us to map it
658-
val strftimeFormat = expr.right match {
659-
case Literal(fmt: UTF8String, _) =>
660-
supportedFormats.get(fmt.toString)
661-
case _ => None
665+
val tsProto = exprToProtoInternal(expr.left, inputs, binding)
666+
val fmtProto = exprToProtoInternal(expr.right, inputs, binding)
667+
val tzProto = exprToProtoInternal(Literal(timezone), inputs, binding)
668+
if (tsProto.isEmpty || fmtProto.isEmpty || tzProto.isEmpty) {
669+
withInfo(expr, expr.left, expr.right)
670+
return None
662671
}
663-
664-
strftimeFormat match {
665-
case Some(format) =>
666-
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
667-
val formatExpr = exprToProtoInternal(Literal(format), inputs, binding)
668-
669-
val optExpr = scalarFunctionExprToProtoWithReturnType(
670-
"to_char",
671-
StringType,
672-
false,
673-
childExpr,
674-
formatExpr)
675-
optExprWithInfo(optExpr, expr, expr.left, expr.right)
676-
case None =>
677-
withInfo(expr, expr.left, expr.right)
678-
None
672+
val returnType = serializeDataType(StringType).getOrElse {
673+
withInfo(expr, "Failed to serialize StringType return type")
674+
return None
679675
}
676+
val udfBuilder = ExprOuterClass.JvmScalarUdf
677+
.newBuilder()
678+
.setClassName(classOf[DateFormatUDF].getName)
679+
.addArgs(tsProto.get)
680+
.addArgs(fmtProto.get)
681+
.addArgs(tzProto.get)
682+
.setReturnType(returnType)
683+
.setReturnNullable(expr.nullable)
684+
Some(
685+
ExprOuterClass.Expr
686+
.newBuilder()
687+
.setJvmScalarUdf(udfBuilder.build())
688+
.build())
680689
}
681690
}
682691

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import java.nio.charset.StandardCharsets.UTF_8
23+
import java.util.concurrent.ConcurrentHashMap
24+
25+
import org.apache.arrow.vector.{TimeStampMicroTZVector, TimeStampMicroVector, ValueVector, VarCharVector}
26+
import org.apache.spark.sql.catalyst.expressions.{BoundReference, DateFormatClass, Literal}
27+
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
28+
import org.apache.spark.sql.types.{StringType, TimestampType}
29+
import org.apache.spark.unsafe.types.UTF8String
30+
31+
import org.apache.comet.CometArrowAllocator
32+
33+
/**
 * `date_format(timestamp, format)` implemented by delegating to Spark's `DateFormatClass`.
 *
 * Used as the JVM fallback when CometDateFormat cannot push to native (non-UTC timezone, format
 * outside the strftime-mappable whitelist, non-literal format string).
 *
 * Inputs:
 *   - inputs(0): TimeStampMicro[TZ]Vector - timestamp column (microseconds since epoch)
 *   - inputs(1): VarCharVector - format string; length-1 if literal, else per-row
 *   - inputs(2): VarCharVector - session timezone id (length-1 scalar)
 *
 * Output: VarCharVector of length `numRows`. A row is null when its timestamp or its format is
 * null, or when Spark's `DateFormatClass.eval` returns null for it.
 *
 * If evaluation fails part-way through a batch, the partially filled output vector is closed
 * before the exception is rethrown, so its buffers are returned to the allocator.
 */
class DateFormatUDF extends CometUDF {

  // Upper bound on the number of cached (format, timezone) entries. A literal format only ever
  // produces one entry, but a format read from a column can carry arbitrarily many distinct
  // values; without a bound the cache would grow for as long as this instance is alive.
  private val maxCachedFormatters = 1024

  // Cache one DateFormatClass per (format, timezone). Constructing it with a Literal format makes
  // its `formatterOption` lazy-val resolve to Some(formatter), so subsequent eval calls reuse the
  // formatter instead of rebuilding it per row.
  private val cache = new ConcurrentHashMap[(String, String), DateFormatClass]()

  /**
   * Returns the cached `DateFormatClass` for `(formatStr, tzId)`, creating it on first use.
   *
   * When the cache has reached `maxCachedFormatters` entries it is emptied before the lookup.
   * Instances already handed out stay valid; only future lookups pay the rebuild cost.
   */
  private def lookup(formatStr: String, tzId: String): DateFormatClass = {
    if (cache.size() >= maxCachedFormatters) {
      cache.clear()
    }
    cache.computeIfAbsent(
      (formatStr, tzId),
      _ =>
        DateFormatClass(
          BoundReference(0, TimestampType, nullable = true),
          Literal(UTF8String.fromString(formatStr), StringType),
          Some(tzId)))
  }

  /**
   * Formats every timestamp in `inputs(0)` using the format in `inputs(1)` and the timezone in
   * `inputs(2)`.
   *
   * @param inputs
   *   exactly three vectors: timestamp, format, timezone (see the class documentation)
   * @param numRows
   *   number of rows to produce
   * @return
   *   a newly allocated `VarCharVector` with `numRows` values
   * @throws IllegalArgumentException
   *   if `inputs` does not have three elements or the timezone is missing/null
   * @throws RuntimeException
   *   if `inputs(0)` is not a `TimeStampMicroVector` or `TimeStampMicroTZVector`
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(
      inputs.length == 3,
      s"DateFormatUDF expects 3 inputs (timestamp, format, timezone), got ${inputs.length}")
    val tsVec = inputs(0)
    val fmtVec = inputs(1).asInstanceOf[VarCharVector]
    val tzVec = inputs(2).asInstanceOf[VarCharVector]
    require(
      tzVec.getValueCount >= 1 && !tzVec.isNull(0),
      "DateFormatUDF requires a non-null scalar timezone")

    val tzId = new String(tzVec.get(0), UTF_8)
    val fmtScalar = fmtVec.getValueCount == 1
    // For scalar format the format never varies, so resolve the DateFormatClass once and reuse it
    // across every row instead of doing a Tuple2 allocation + HashMap lookup per row. `null` is
    // used (rather than Option) to keep the per-row check allocation-free; it means either the
    // format is per-row or the scalar format is null (in which case every row is null).
    val scalarDf: DateFormatClass =
      if (fmtScalar && !fmtVec.isNull(0)) lookup(new String(fmtVec.get(0), UTF_8), tzId)
      else null

    // Resolved before the output vector is allocated, so an unsupported input type cannot leak.
    val getMicros: Int => Long = tsVec match {
      case t: TimeStampMicroTZVector => i => t.get(i)
      case t: TimeStampMicroVector => i => t.get(i)
      case other =>
        throw new RuntimeException(
          s"DateFormatUDF: unsupported timestamp vector ${other.getClass.getName}")
    }

    val out = new VarCharVector("date_format_result", CometArrowAllocator)
    try {
      out.allocateNew(numRows)

      // Single reusable row: DateFormatClass reads its timestamp through BoundReference(0).
      val row = new GenericInternalRow(1)

      var i = 0
      while (i < numRows) {
        val fmtIdx = if (fmtScalar) 0 else i
        if (tsVec.isNull(i) || fmtVec.isNull(fmtIdx)) {
          out.setNull(i)
        } else {
          val df =
            if (scalarDf != null) scalarDf
            else lookup(new String(fmtVec.get(fmtIdx), UTF_8), tzId)
          row.update(0, getMicros(i))
          val formatted = df.eval(row)
          if (formatted == null) out.setNull(i)
          else out.setSafe(i, formatted.asInstanceOf[UTF8String].getBytes)
        }
        i += 1
      }
      out.setValueCount(numRows)
      out
    } catch {
      // Cleanup only: release the half-built vector, then let the original failure propagate.
      case t: Throwable =>
        out.close()
        throw t
    }
  }
}

spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
214214
}
215215

216216
test("date_format - timestamp_ntz input") {
217-
// TimestampNTZ is timezone-independent, so date_format should produce the same
218-
// formatted string regardless of session timezone. Comet currently only runs this
219-
// natively for UTC; for non-UTC it falls back to Spark. We verify correctness
220-
// (matching Spark's output) in all cases.
217+
// TimestampNTZ is timezone-independent, so date_format must produce the same string
218+
// regardless of session timezone.
221219
val r = new Random(42)
222220
val ntzSchema = StructType(Seq(StructField("ts_ntz", DataTypes.TimestampNTZType, true)))
223221
val ntzDF = FuzzDataGenerator.generateDataFrame(r, spark, ntzSchema, 100, DataGenOptions())
@@ -227,14 +225,8 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
227225
for (tz <- crossTimezones) {
228226
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
229227
for (format <- supportedFormats) {
230-
if (tz == "UTC") {
231-
checkSparkAnswerAndOperator(
232-
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
233-
} else {
234-
// Non-UTC falls back to Spark but should still produce correct results
235-
checkSparkAnswer(
236-
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
237-
}
228+
checkSparkAnswerAndOperator(
229+
s"SELECT ts_ntz, date_format(ts_ntz, '$format') from ntz_tbl order by ts_ntz")
238230
}
239231
}
240232
}
@@ -476,45 +468,44 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
476468
}
477469
}
478470

479-
test("date_format unsupported format falls back to Spark") {
471+
test("date_format unsupported format runs via JVM UDF inside Comet") {
480472
createTimestampTestData.createOrReplaceTempView("tbl")
481473

482474
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
483-
// Unsupported format string
484-
checkSparkAnswerAndFallbackReason(
485-
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0",
486-
"Format 'yyyy-MM-dd EEEE' is not supported")
475+
checkSparkAnswerAndOperator(
476+
"SELECT c0, date_format(c0, 'yyyy-MM-dd EEEE') from tbl order by c0")
487477
}
488478
}
489479

490-
test("date_format with non-UTC timezone falls back to Spark") {
480+
test("date_format with non-UTC timezone runs via JVM UDF inside Comet") {
491481
createTimestampTestData.createOrReplaceTempView("tbl")
492482

493483
val nonUtcTimezones =
494484
Seq("America/New_York", "America/Los_Angeles", "Europe/London", "Asia/Tokyo")
495485

496486
for (tz <- nonUtcTimezones) {
497487
withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
498-
// Non-UTC timezones should fall back to Spark as Incompatible
499-
checkSparkAnswerAndFallbackReason(
500-
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0",
501-
s"Non-UTC timezone '$tz' may produce different results")
488+
checkSparkAnswerAndOperator(
489+
"SELECT c0, date_format(c0, 'yyyy-MM-dd HH:mm:ss') from tbl order by c0")
502490
}
503491
}
504492
}
505493

506-
test("date_format with non-UTC timezone works when allowIncompatible is enabled") {
494+
test("date_format with non-UTC timezone takes native path when allowIncompatible is enabled") {
507495
createTimestampTestData.createOrReplaceTempView("tbl")
508496

509497
val nonUtcTimezones = Seq("America/New_York", "Europe/London", "Asia/Tokyo")
510498

511499
for (tz <- nonUtcTimezones) {
512500
withSQLConf(
513501
SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz,
514-
"spark.comet.expr.DateFormatClass.allowIncompatible" -> "true") {
515-
// With allowIncompatible enabled, Comet will execute the expression
516-
// Results may differ from Spark but should not throw errors
517-
checkSparkAnswer("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl order by c0")
502+
"spark.comet.expression.DateFormatClass.allowIncompatible" -> "true") {
503+
// Native to_char results may diverge from Spark for non-UTC timezones (the reason the
504+
// JVM UDF is the default), so we only check that execution stays inside Comet. ORDER BY
505+
// is omitted to keep the plan free of AQEShuffleRead.
506+
val df = sql("SELECT c0, date_format(c0, 'yyyy-MM-dd') from tbl")
507+
df.collect()
508+
checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan))
518509
}
519510
}
520511
}

0 commit comments

Comments
 (0)