Skip to content

Commit 267ad4c

Browse files
authored
feat: add support for datediff expression (#3145)
1 parent f63c6a6 commit 267ad4c

8 files changed

Lines changed: 166 additions & 4 deletions

File tree

docs/source/user-guide/latest/configs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ These settings can be used to determine which parts of the plan are accelerated
234234
| `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true |
235235
| `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true |
236236
| `spark.comet.expression.DateAdd.enabled` | Enable Comet acceleration for `DateAdd` | true |
237+
| `spark.comet.expression.DateDiff.enabled` | Enable Comet acceleration for `DateDiff` | true |
237238
| `spark.comet.expression.DateFormatClass.enabled` | Enable Comet acceleration for `DateFormatClass` | true |
238239
| `spark.comet.expression.DateSub.enabled` | Enable Comet acceleration for `DateSub` | true |
239240
| `spark.comet.expression.DayOfMonth.enabled` | Enable Comet acceleration for `DayOfMonth` | true |

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
2222
use crate::{
2323
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
2424
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
25-
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
26-
SparkStringSpace,
25+
spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateDiff, SparkDateTrunc,
26+
SparkSizeFunc, SparkStringSpace,
2727
};
2828
use arrow::datatypes::DataType;
2929
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode(
192192
fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
193193
vec![
194194
Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
195+
Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())),
195196
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
196197
Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
197198
Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, Date32Array, Int32Array};
19+
use arrow::compute::kernels::arity::binary;
20+
use arrow::datatypes::DataType;
21+
use datafusion::common::{utils::take_function_args, DataFusionError, Result};
22+
use datafusion::logical_expr::{
23+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24+
};
25+
use std::any::Any;
26+
use std::sync::Arc;
27+
28+
/// Spark-compatible date_diff function.
29+
/// Returns the number of days from startDate to endDate (endDate - startDate).
30+
#[derive(Debug, PartialEq, Eq, Hash)]
31+
pub struct SparkDateDiff {
32+
signature: Signature,
33+
aliases: Vec<String>,
34+
}
35+
36+
impl SparkDateDiff {
37+
pub fn new() -> Self {
38+
Self {
39+
signature: Signature::exact(
40+
vec![DataType::Date32, DataType::Date32],
41+
Volatility::Immutable,
42+
),
43+
aliases: vec!["datediff".to_string()],
44+
}
45+
}
46+
}
47+
48+
impl Default for SparkDateDiff {
49+
fn default() -> Self {
50+
Self::new()
51+
}
52+
}
53+
54+
impl ScalarUDFImpl for SparkDateDiff {
55+
fn as_any(&self) -> &dyn Any {
56+
self
57+
}
58+
59+
fn name(&self) -> &str {
60+
"date_diff"
61+
}
62+
63+
fn signature(&self) -> &Signature {
64+
&self.signature
65+
}
66+
67+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
68+
Ok(DataType::Int32)
69+
}
70+
71+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
72+
let [end_date, start_date] = take_function_args(self.name(), args.args)?;
73+
74+
// Convert scalars to arrays for uniform processing
75+
let end_arr = end_date.into_array(1)?;
76+
let start_arr = start_date.into_array(1)?;
77+
78+
let end_date_array = end_arr
79+
.as_any()
80+
.downcast_ref::<Date32Array>()
81+
.ok_or_else(|| {
82+
DataFusionError::Execution("date_diff expects Date32Array for end_date".to_string())
83+
})?;
84+
85+
let start_date_array = start_arr
86+
.as_any()
87+
.downcast_ref::<Date32Array>()
88+
.ok_or_else(|| {
89+
DataFusionError::Execution(
90+
"date_diff expects Date32Array for start_date".to_string(),
91+
)
92+
})?;
93+
94+
// Date32 stores days since epoch, so difference is just subtraction
95+
let result: Int32Array =
96+
binary(end_date_array, start_date_array, |end, start| end - start)?;
97+
98+
Ok(ColumnarValue::Array(Arc::new(result)))
99+
}
100+
101+
fn aliases(&self) -> &[String] {
102+
&self.aliases
103+
}
104+
}

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod date_diff;
1819
mod date_trunc;
1920
mod extract_date_part;
2021
mod timestamp_trunc;
2122

23+
pub use date_diff::SparkDateDiff;
2224
pub use date_trunc::SparkDateTrunc;
2325
pub use extract_date_part::SparkHour;
2426
pub use extract_date_part::SparkMinute;

native/spark-expr/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ pub use comet_scalar_funcs::{
6969
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
7070
register_all_comet_functions,
7171
};
72-
pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
72+
pub use datetime_funcs::{
73+
SparkDateDiff, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr,
74+
};
7375
pub use error::{SparkError, SparkResult};
7476
pub use hash_funcs::*;
7577
pub use json_funcs::{FromJson, ToJson};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
185185

186186
private val temporalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
187187
classOf[DateAdd] -> CometDateAdd,
188+
classOf[DateDiff] -> CometDateDiff,
188189
classOf[DateFormatClass] -> CometDateFormat,
189190
classOf[DateSub] -> CometDateSub,
190191
classOf[UnixDate] -> CometUnixDate,

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, WeekDay, WeekOfYear, Year}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, WeekDay, WeekOfYear, Year}
2525
import org.apache.spark.sql.types.{DateType, IntegerType, StringType}
2626
import org.apache.spark.unsafe.types.UTF8String
2727

@@ -258,6 +258,8 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
258258

259259
object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
260260

261+
object CometDateDiff extends CometScalarFunction[DateDiff]("date_diff")
262+
261263
/**
262264
* Converts a date to the number of days since Unix epoch (1970-01-01). Since dates are internally
263265
* stored as days since epoch, this is a simple cast to integer.

spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,55 @@ class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanH
123123
FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
124124
}
125125

126+
test("datediff") {
127+
val r = new Random(42)
128+
val schema = StructType(
129+
Seq(
130+
StructField("c0", DataTypes.DateType, true),
131+
StructField("c1", DataTypes.DateType, true)))
132+
val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
133+
df.createOrReplaceTempView("tbl")
134+
135+
// Basic test with random dates
136+
checkSparkAnswerAndOperator("SELECT c0, c1, datediff(c0, c1) FROM tbl ORDER BY c0, c1")
137+
138+
// Disable constant folding to ensure literal expressions are executed by Comet
139+
withSQLConf(
140+
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
141+
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
142+
// Test positive difference (end date > start date)
143+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-31'), DATE('2009-07-30'))")
144+
145+
// Test negative difference (end date < start date)
146+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), DATE('2009-07-31'))")
147+
148+
// Test same dates (should be 0)
149+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), DATE('2009-07-30'))")
150+
151+
// Test larger date differences
152+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2024-01-01'), DATE('2020-01-01'))")
153+
154+
// Test null handling
155+
checkSparkAnswerAndOperator("SELECT datediff(NULL, DATE('2009-07-30'))")
156+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2009-07-30'), NULL)")
157+
158+
// Test leap year edge cases
159+
// 1900 was NOT a leap year (divisible by 100 but not 400)
160+
// 2000 WAS a leap year (divisible by 400)
161+
// So Feb 27 to Mar 1 spans different number of days:
162+
// 1900: 2 days (Feb 28, Mar 1)
163+
// 2000: 3 days (Feb 28, Feb 29, Mar 1)
164+
checkSparkAnswerAndOperator("SELECT datediff(DATE('1900-03-01'), DATE('1900-02-27'))")
165+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2000-03-01'), DATE('2000-02-27'))")
166+
167+
// Additional leap year tests
168+
// 2004 was a leap year (divisible by 4, not by 100)
169+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2004-03-01'), DATE('2004-02-28'))")
170+
// 2100 will NOT be a leap year (divisible by 100 but not 400)
171+
checkSparkAnswerAndOperator("SELECT datediff(DATE('2100-03-01'), DATE('2100-02-28'))")
172+
}
173+
}
174+
126175
test("date_format with timestamp column") {
127176
// Filter out formats with embedded quotes that need special handling
128177
val supportedFormats = CometDateFormat.supportedFormats.keys.toSeq

0 commit comments

Comments
 (0)