Skip to content

Commit 4366eca

Browse files
kazantsev-maksimKazantsev Maksim
authored andcommitted
Chore: Used DataFusion impl of date_add and date_sub functions (apache#2473)
* Date_add and date_sub to DataFusion impl * Fix tests --------- Co-authored-by: Kazantsev Maksim <mn.kazantsev@gmail.com>
1 parent b7685d2 commit 4366eca

7 files changed

Lines changed: 15 additions & 146 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ use datafusion::{
4141
};
4242
use datafusion_comet_proto::spark_operator::Operator;
4343
use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
44+
use datafusion_spark::function::datetime::date_add::SparkDateAdd;
45+
use datafusion_spark::function::datetime::date_sub::SparkDateSub;
4446
use datafusion_spark::function::hash::sha2::SparkSha2;
4547
use datafusion_spark::function::math::expm1::SparkExpm1;
4648
use datafusion_spark::function::string::char::CharFunc;
@@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
303305
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
304306
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
305307
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
308+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
309+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
306310

307311
// Must be the last one to override existing functions with the same name
308312
datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ use crate::hash_funcs::*;
1919
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
2020
use crate::math_funcs::modulo_expr::spark_modulo;
2121
use crate::{
22-
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
23-
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
24-
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
25-
SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
22+
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
23+
spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
24+
spark_unhex, spark_unscaled_value, SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc,
25+
SparkStringSpace,
2626
};
2727
use arrow::datatypes::DataType;
2828
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -140,14 +140,6 @@ pub fn create_comet_physical_fun(
140140
let func = Arc::new(spark_isnan);
141141
make_comet_scalar_udf!("isnan", func, without data_type)
142142
}
143-
"date_add" => {
144-
let func = Arc::new(spark_date_add);
145-
make_comet_scalar_udf!("date_add", func, without data_type)
146-
}
147-
"date_sub" => {
148-
let func = Arc::new(spark_date_sub);
149-
make_comet_scalar_udf!("date_sub", func, without data_type)
150-
}
151143
"array_repeat" => {
152144
let func = Arc::new(spark_array_repeat);
153145
make_comet_scalar_udf!("array_repeat", func, without data_type)

native/spark-expr/src/datetime_funcs/date_arithmetic.rs

Lines changed: 0 additions & 101 deletions
This file was deleted.

native/spark-expr/src/datetime_funcs/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
mod date_arithmetic;
1918
mod date_trunc;
2019
mod extract_date_part;
2120
mod timestamp_trunc;
2221

23-
pub use date_arithmetic::{spark_date_add, spark_date_sub};
2422
pub use date_trunc::SparkDateTrunc;
2523
pub use extract_date_part::SparkHour;
2624
pub use extract_date_part::SparkMinute;

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,7 @@ pub use conversion_funcs::*;
6565
pub use nondetermenistic_funcs::*;
6666

6767
pub use comet_scalar_funcs::{create_comet_physical_fun, register_all_comet_functions};
68-
pub use datetime_funcs::{
69-
spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
70-
TimestampTruncExpr,
71-
};
68+
pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
7269
pub use error::{SparkError, SparkResult};
7370
pub use hash_funcs::*;
7471
pub use json_funcs::ToJson;

spark/src/main/scala/org/apache/comet/serde/datetime.scala

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
2525
import org.apache.comet.CometSparkSessionExtensions.withInfo
2626
import org.apache.comet.serde.CometGetDateField.CometGetDateField
2727
import org.apache.comet.serde.ExprOuterClass.Expr
28-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
28+
import org.apache.comet.serde.QueryPlanSerde._
2929

3030
private object CometGetDateField extends Enumeration {
3131
type CometGetDateField = Value
@@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
251251
}
252252
}
253253

254-
object CometDateAdd extends CometExpressionSerde[DateAdd] {
255-
override def convert(
256-
expr: DateAdd,
257-
inputs: Seq[Attribute],
258-
binding: Boolean): Option[ExprOuterClass.Expr] = {
259-
val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
260-
val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
261-
val optExpr =
262-
scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
263-
optExprWithInfo(optExpr, expr, expr.left, expr.right)
264-
}
265-
}
254+
object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
266255

267-
object CometDateSub extends CometExpressionSerde[DateSub] {
268-
override def convert(
269-
expr: DateSub,
270-
inputs: Seq[Attribute],
271-
binding: Boolean): Option[ExprOuterClass.Expr] = {
272-
val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
273-
val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
274-
val optExpr =
275-
scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
276-
optExprWithInfo(optExpr, expr, expr.left, expr.right)
277-
}
278-
}
256+
object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
279257

280258
object CometTruncDate extends CometExpressionSerde[TruncDate] {
281259
override def convert(

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
247247
} else {
248248
assert(sparkErr.get.getMessage.contains("integer overflow"))
249249
}
250-
assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
250+
assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
251251
}
252252
}
253253
}
@@ -291,10 +291,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
291291
checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
292292
if (isSpark40Plus) {
293293
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
294+
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
294295
} else {
295296
assert(sparkErr.get.getMessage.contains("integer overflow"))
297+
assert(cometErr.get.getMessage.contains("integer overflow"))
296298
}
297-
assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
298299
}
299300
}
300301
}

0 commit comments

Comments
 (0)