Skip to content

Commit f383a0c

Browse files
authored
feat: expand date/time expression support using codegen dispatcher (#4417)
1 parent 0d5c592 commit f383a0c

24 files changed

Lines changed: 1000 additions & 109 deletions

spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
261261
val subExprsCode = ctx.subexprFunctionsCode
262262
val (cls, setup, snippet) =
263263
CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx)
264-
(cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode))
264+
(cls, setup, defaultBody(boundExpr, inputSchema, ev, snippet, subExprsCode))
265265
}
266266

267267
val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema)
@@ -343,6 +343,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
343343
*/
344344
private def defaultBody(
345345
boundExpr: Expression,
346+
inputSchema: Seq[ArrowColumnSpec],
346347
ev: ExprCode,
347348
writeSnippet: String,
348349
subExprsCode: String): String = {
@@ -353,9 +354,17 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim {
353354
// make this incorrect (`coalesce(null, x)` is `x`); `allNullIntolerant` rejects those.
354355
val inputOrdinals =
355356
boundExpr.collect { case b: BoundReference => b.ordinal }.distinct
357+
// Primitive Arrow vectors are wrapped in `CometPlainVector` at input-cast time, which
358+
// exposes `isNullAt(int)` rather than the raw Arrow `isNull(int)`. Pick the right method
359+
// per ordinal so the short-circuit compiles for timestamp / int / float columns too,
360+
// not just VarChar / Decimal vectors that stay as raw Arrow types.
361+
def nullCheckCall(ord: Int): String = {
362+
val method = CometBatchKernelCodegenInput.nullCheckMethod(inputSchema(ord))
363+
s"this.col$ord.$method(i)"
364+
}
356365
val nullCheck =
357366
if (inputOrdinals.isEmpty) "false"
358-
else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ")
367+
else inputOrdinals.map(nullCheckCall).mkString(" || ")
359368
s"""
360369
|if ($nullCheck) {
361370
| output.setNull(i);

spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,10 @@ private[codegen] object CometBatchKernelCodegenInput {
404404
/**
405405
* Java method name for the per-column null check. Primitive scalars wrapped in
406406
* [[CometPlainVector]] expose `isNullAt`; Arrow typed fields expose `isNull`. Same semantics.
407+
* Used both by `emitTypedGetters` (for the kernel's `isNullAt` switch) and by
408+
* `CometBatchKernelCodegen.defaultBody` (for the `NullIntolerant` short-circuit).
407409
*/
408-
private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match {
410+
def nullCheckMethod(spec: ArrowColumnSpec): String = spec match {
409411
case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt"
410412
case _ => "isNull"
411413
}

spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.apache.comet.serde
2121

2222
import org.apache.spark.SparkEnv
23-
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF}
23+
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Expression, Literal, ScalaUDF}
2424
import org.apache.spark.sql.types.BinaryType
2525

2626
import org.apache.comet.CometConf
@@ -45,15 +45,35 @@ import org.apache.comet.udf.codegen.CometScalaUDFCodegen
4545
*
4646
* Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a
4747
* `ScalaUDF` fall back to Spark for the enclosing operator.
48+
*
49+
* [[emitJvmCodegenDispatch]] exposes the same closure-serialize + dispatcher-proto path to other
50+
* serdes that want to keep a built-in Spark expression inside the Comet pipeline when no native
51+
* lowering is viable. See [[CometDateFormat]] for an example.
4852
*/
4953
object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {
5054

51-
override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
55+
override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] =
56+
emitJvmCodegenDispatch(expr, inputs, binding)
57+
58+
/**
59+
* Bind `expr`, closure-serialize it, and emit a `JvmScalarUdf` proto routed through
60+
* [[CometScalaUDFCodegen]] so that native execution evaluates the expression inside the
61+
* Arrow-direct codegen dispatcher. The dispatcher will Janino-compile `expr.doGenCode` into a
62+
* batch kernel on first invocation per task.
63+
*
64+
* Returns `None` (with `withInfo` tagging the reason) when the dispatcher is disabled via
65+
* [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]] or when [[CometBatchKernelCodegen.canHandle]]
66+
* refuses the expression tree. Callers should treat `None` as a clean Spark-fallback signal.
67+
*/
68+
def emitJvmCodegenDispatch(
69+
expr: Expression,
70+
inputs: Seq[Attribute],
71+
binding: Boolean): Option[Expr] = {
5272
if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) {
5373
withInfo(
5474
expr,
55-
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " +
56-
"so the plan falls back to Spark")
75+
s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; expression has no native " +
76+
"path so the plan falls back to Spark")
5777
return None
5878
}
5979

@@ -100,3 +120,21 @@ object CometScalaUDF extends CometExpressionSerde[ScalaUDF] {
100120
.build())
101121
}
102122
}
123+
124+
/**
125+
* Convenience base for serdes that route a non-ScalaUDF Spark expression through the codegen
126+
* dispatcher. Delegates `convert` to [[CometScalaUDF.emitJvmCodegenDispatch]] and marks the
127+
* expression `Compatible()` because the dispatcher runs Spark's own `doGenCode` inside the
128+
* kernel: behavior matches Spark exactly when [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]] is
129+
* enabled, and the operator falls back to Spark cleanly when it is not.
130+
*/
131+
class CometCodegenDispatch[T <: Expression] extends CometExpressionSerde[T] {
132+
override def getSupportLevel(expr: T): SupportLevel = Compatible()
133+
// Intentionally no getCompatibleNotes override: the docs generator emits compat notes under
134+
// a heading that promises "no additional configuration required". The dispatcher flag is a
135+
// global concern documented elsewhere; tagging each expression here would contradict the
136+
// heading. When the flag is off, `convert` returns None with a clear withInfo reason that
137+
// shows up in EXPLAIN, which is the right place for that signal.
138+
override def convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr] =
139+
CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding)
140+
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
219219

220220
private[comet] val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
221221
Map(
222+
classOf[AddMonths] -> CometAddMonths,
222223
classOf[ConvertTimezone] -> CometConvertTimezone,
223224
classOf[DateAdd] -> CometDateAdd,
224225
classOf[DateDiff] -> CometDateDiff,
@@ -234,12 +235,20 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
234235
classOf[LastDay] -> CometLastDay,
235236
classOf[Hour] -> CometHour,
236237
classOf[MakeDate] -> CometMakeDate,
238+
classOf[MakeTimestamp] -> CometMakeTimestamp,
239+
classOf[MicrosToTimestamp] -> CometMicrosToTimestamp,
240+
classOf[MillisToTimestamp] -> CometMillisToTimestamp,
241+
classOf[MonthsBetween] -> CometMonthsBetween,
237242
classOf[Minute] -> CometMinute,
238243
classOf[NextDay] -> CometNextDay,
239244
classOf[Second] -> CometSecond,
240245
classOf[SecondsToTimestamp] -> CometSecondsToTimestamp,
241246
classOf[TruncDate] -> CometTruncDate,
242247
classOf[TruncTimestamp] -> CometTruncTimestamp,
248+
classOf[ToUnixTimestamp] -> CometToUnixTimestamp,
249+
classOf[UnixMicros] -> CometUnixMicros,
250+
classOf[UnixMillis] -> CometUnixMillis,
251+
classOf[UnixSeconds] -> CometUnixSeconds,
243252
classOf[UnixTimestamp] -> CometUnixTimestamp,
244253
classOf[Year] -> CometYear,
245254
classOf[Month] -> CometMonth,

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, Hour, Hours, LastDay, Literal, MakeDate, Minute, Month, NextDay, Quarter, Second, SecondsToTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year}
24+
import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
2525
import org.apache.spark.sql.internal.SQLConf
2626
import org.apache.spark.sql.types.{DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
2727
import org.apache.spark.unsafe.types.UTF8String
2828

29+
import org.apache.comet.CometConf
2930
import org.apache.comet.CometSparkSessionExtensions.withInfo
3031
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3132
import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -593,17 +594,23 @@ object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
593594
}
594595

595596
/**
596-
* Converts Spark DateFormatClass expression to DataFusion's to_char function.
597+
* Converts Spark `DateFormatClass` to DataFusion's `to_char` when format and timezone are
598+
* mappable, otherwise routes the expression through the Arrow-direct codegen dispatcher so that
599+
* Spark's own `DateFormatClass.doGenCode` runs inside the Comet pipeline.
597600
*
598-
* Spark uses Java SimpleDateFormat patterns while DataFusion uses strftime patterns. This
599-
* implementation supports a whitelist of common format strings that can be reliably mapped
600-
* between the two systems.
601+
* Routing:
602+
* - format is a literal in `supportedFormats` AND timezone is `UTC` / `Etc/UTC` (or unset) -> native `to_char`
603+
* - format is a literal in `supportedFormats` AND timezone is non-UTC, with the per-expression
604+
* `allowIncompatible` flag set -> native `to_char` (results may differ from Spark)
605+
* - all other cases -> JVM codegen dispatcher ([[CometScalaUDF.emitJvmCodegenDispatch]]), gated
606+
* by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator
607+
* falls back to Spark.
601608
*/
602609
object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
603610

604611
/**
605612
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
606-
* are supported.
613+
* are supported by the native path.
607614
*/
608615
val supportedFormats: Map[String, String] = Map(
609616
// Full date formats
@@ -637,66 +644,50 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] {
637644
// ISO formats
638645
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")
639646

640-
override def getIncompatibleReasons(): Seq[String] = Seq(
641-
"Non-UTC timezones may produce different results than Spark")
642-
643-
override def getUnsupportedReasons(): Seq[String] = Seq(
644-
"Only the following formats are supported:" +
645-
supportedFormats.keys.toSeq.sorted
646-
.map(k => s"`$k`")
647-
.mkString("\n - ", "\n - ", ""))
647+
// Compatibility is decided inside `convert`: the native path covers a subset, and the codegen
648+
// dispatcher covers everything else when enabled. Plan-time tagging happens via `withInfo` on
649+
// the path that returns None.
650+
override def getSupportLevel(expr: DateFormatClass): SupportLevel = Compatible()
648651

649-
override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
650-
// Check timezone - only UTC is fully compatible
651-
val timezone = expr.timeZoneId.getOrElse("UTC")
652-
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
653-
654-
expr.right match {
655-
case Literal(fmt: UTF8String, _) =>
656-
val format = fmt.toString
657-
if (supportedFormats.contains(format)) {
658-
if (isUtc) {
659-
Compatible()
660-
} else {
661-
Incompatible(Some(s"Non-UTC timezone '$timezone' may produce different results"))
662-
}
663-
} else {
664-
Unsupported(
665-
Some(
666-
s"Format '$format' is not supported. Supported formats: " +
667-
supportedFormats.keys.mkString(", ")))
668-
}
669-
case _ =>
670-
Unsupported(Some("Only literal format strings are supported"))
671-
}
672-
}
652+
override def getCompatibleNotes(): Seq[String] = Seq(
653+
"Format strings in a curated allow-list run natively via DataFusion's `to_char` for UTC " +
654+
"sessions. Other format strings (including non-literal formats), as well as non-UTC " +
655+
"sessions, route through Spark's own `DateFormatClass.doGenCode` via the Arrow-direct " +
656+
"codegen dispatcher when `spark.comet.exec.scalaUDF.codegen.enabled=true`. When the " +
657+
"codegen dispatcher is disabled (default) the operator falls back to Spark in those " +
658+
"cases.")
673659

674660
override def convert(
675661
expr: DateFormatClass,
676662
inputs: Seq[Attribute],
677663
binding: Boolean): Option[ExprOuterClass.Expr] = {
678-
// Get the format string - must be a literal for us to map it
679-
val strftimeFormat = expr.right match {
680-
case Literal(fmt: UTF8String, _) =>
681-
supportedFormats.get(fmt.toString)
664+
val timezone = expr.timeZoneId.getOrElse("UTC")
665+
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
666+
667+
val nativeFormat: Option[String] = expr.right match {
668+
case Literal(fmt: UTF8String, _) => supportedFormats.get(fmt.toString)
682669
case _ => None
683670
}
684671

685-
strftimeFormat match {
686-
case Some(format) =>
687-
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
688-
val formatExpr = exprToProtoInternal(Literal(format), inputs, binding)
689-
690-
val optExpr = scalarFunctionExprToProtoWithReturnType(
691-
"to_char",
692-
StringType,
693-
false,
694-
childExpr,
695-
formatExpr)
696-
optExprWithInfo(optExpr, expr, expr.left, expr.right)
697-
case None =>
698-
withInfo(expr, expr.left, expr.right)
699-
None
672+
val canUseNative = nativeFormat.isDefined && {
673+
isUtc || CometConf.isExprAllowIncompat(getExprConfigName(expr))
674+
}
675+
676+
if (canUseNative) {
677+
val childExpr = exprToProtoInternal(expr.left, inputs, binding)
678+
val formatExpr = exprToProtoInternal(Literal(nativeFormat.get), inputs, binding)
679+
val optExpr = scalarFunctionExprToProtoWithReturnType(
680+
"to_char",
681+
StringType,
682+
false,
683+
childExpr,
684+
formatExpr)
685+
optExprWithInfo(optExpr, expr, expr.left, expr.right)
686+
} else {
687+
// Hand the full `DateFormatClass` (with `timeZoneId` already stamped by `ResolveTimeZone`)
688+
// to the codegen dispatcher. It closure-serializes the bound tree, so non-UTC timezones
689+
// and non-whitelisted / non-literal format strings produce Spark-identical results.
690+
CometScalaUDF.emitJvmCodegenDispatch(expr, inputs, binding)
700691
}
701692
}
702693
}
@@ -780,3 +771,21 @@ object CometDays extends CometExpressionSerde[Days] {
780771
optExprWithInfo(optExpr, expr, expr.child)
781772
}
782773
}
774+
775+
object CometAddMonths extends CometCodegenDispatch[AddMonths]
776+
777+
object CometMonthsBetween extends CometCodegenDispatch[MonthsBetween]
778+
779+
object CometMakeTimestamp extends CometCodegenDispatch[MakeTimestamp]
780+
781+
object CometMicrosToTimestamp extends CometCodegenDispatch[MicrosToTimestamp]
782+
783+
object CometMillisToTimestamp extends CometCodegenDispatch[MillisToTimestamp]
784+
785+
object CometUnixSeconds extends CometCodegenDispatch[UnixSeconds]
786+
787+
object CometUnixMillis extends CometCodegenDispatch[UnixMillis]
788+
789+
object CometUnixMicros extends CometCodegenDispatch[UnixMicros]
790+
791+
object CometToUnixTimestamp extends CometCodegenDispatch[ToUnixTimestamp]
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing,
12+
-- software distributed under the License is distributed on an
13+
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
-- KIND, either express or implied. See the License for the
15+
-- specific language governing permissions and limitations
16+
-- under the License.
17+
18+
-- Routes add_months through the codegen dispatcher. Spark's own AddMonths.doGenCode
19+
-- runs inside the Janino-compiled kernel.
20+
-- Config: spark.sql.session.timeZone=America/Los_Angeles
21+
-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true
22+
23+
statement
24+
CREATE TABLE test_add_months(d date, n int) USING parquet
25+
26+
statement
27+
INSERT INTO test_add_months VALUES
28+
(date('2024-01-15'), 1),
29+
(date('2024-01-31'), 1),
30+
(date('2024-12-15'), -13),
31+
(date('1970-01-01'), 0),
32+
(NULL, 1),
33+
(date('2024-06-15'), NULL)
34+
35+
query
36+
SELECT add_months(d, n) FROM test_add_months
37+
38+
query
39+
SELECT add_months(d, 12) FROM test_add_months
40+
41+
-- literal arguments
42+
query
43+
SELECT add_months(date('2024-02-29'), 12)

spark/src/test/resources/sql-tests/expressions/datetime/date_format.sql

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,27 @@
1515
-- specific language governing permissions and limitations
1616
-- under the License.
1717

18+
-- Pin the session timezone so the test exercises the non-UTC path regardless of the JVM
19+
-- default. Enable the codegen dispatcher so non-UTC and non-whitelisted formats stay inside
20+
-- Comet via Spark's own DateFormatClass.doGenCode instead of falling back to Spark.
21+
-- Config: spark.sql.session.timeZone=America/Los_Angeles
22+
-- Config: spark.comet.exec.scalaUDF.codegen.enabled=true
23+
1824
statement
1925
CREATE TABLE test_date_format(ts timestamp) USING parquet
2026

2127
statement
2228
INSERT INTO test_date_format VALUES (timestamp('2024-06-15 10:30:45')), (timestamp('1970-01-01 00:00:00')), (NULL)
2329

24-
query expect_fallback(Non-UTC timezone)
30+
query
2531
SELECT date_format(ts, 'yyyy-MM-dd') FROM test_date_format
2632

27-
query expect_fallback(Non-UTC timezone)
33+
query
2834
SELECT date_format(ts, 'HH:mm:ss') FROM test_date_format
2935

30-
query expect_fallback(Non-UTC timezone)
36+
query
3137
SELECT date_format(ts, 'yyyy-MM-dd HH:mm:ss') FROM test_date_format
3238

3339
-- literal arguments
34-
query expect_fallback(Non-UTC timezone)
40+
query
3541
SELECT date_format(timestamp('2024-06-15 10:30:45'), 'yyyy-MM-dd')

0 commit comments

Comments
 (0)