Skip to content

Commit 2a15df6

Browse files
committed
chore: wire rint
1 parent b993c0f commit 2a15df6

5 files changed

Lines changed: 65 additions & 1 deletion

File tree

docs/source/contributor-guide/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@
411411
- [x] randn
412412
- [ ] random
413413
- [ ] randstr
414-
- [ ] rint
414+
- [x] rint
415415
- [x] round
416416
- [ ] sec
417417
- [x] shiftleft

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ of expressions that be disabled.
170170
| Rand | `rand` |
171171
| Randn | `randn` |
172172
| Remainder | `%` |
173+
| Rint | `rint` |
173174
| Round | `round` |
174175
| Signum | `signum` |
175176
| Sin | `sin` |

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use datafusion_spark::function::map::str_to_map::SparkStrToMap;
6060
use datafusion_spark::function::math::expm1::SparkExpm1;
6161
use datafusion_spark::function::math::factorial::SparkFactorial;
6262
use datafusion_spark::function::math::hex::SparkHex;
63+
use datafusion_spark::function::math::rint::SparkRint;
6364
use datafusion_spark::function::math::trigonometry::SparkCsc;
6465
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
6566
use datafusion_spark::function::string::char::CharFunc;
@@ -599,6 +600,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
599600
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkTryUrlDecode::default()));
600601
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkCsc::default()));
601602
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkFactorial::default()));
603+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkRint::default()));
602604
}
603605

604606
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
125125
classOf[Rand] -> CometRand,
126126
classOf[Randn] -> CometRandn,
127127
classOf[Remainder] -> CometRemainder,
128+
classOf[Rint] -> CometScalarFunction("rint"),
128129
classOf[Round] -> CometRound,
129130
classOf[Signum] -> CometScalarFunction("signum"),
130131
classOf[Sin] -> CometScalarFunction("sin"),
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing,
12+
-- software distributed under the License is distributed on an
13+
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
-- KIND, either express or implied. See the License for the
15+
-- specific language governing permissions and limitations
16+
-- under the License.
17+
18+
-- ConfigMatrix: parquet.enable.dictionary=false,true
19+
20+
-- Spark's Rint extends UnaryMathExpression with inputTypes = Seq(DoubleType).
21+
-- It returns the double value closest to the argument equal to a mathematical integer
22+
-- (Java's Math.rint, IEEE 754 round-half-to-even / banker's rounding).
23+
24+
statement
25+
CREATE TABLE test_rint(v double) USING parquet
26+
27+
statement
28+
INSERT INTO test_rint VALUES
29+
(0.0),
30+
(-0.0),
31+
(1.0),
32+
(-1.0),
33+
(0.4),
34+
(0.5),
35+
(0.6),
36+
(1.5),
37+
(2.5),
38+
(3.5),
39+
(-0.4),
40+
(-0.5),
41+
(-0.6),
42+
(-1.5),
43+
(-2.5),
44+
(-3.5),
45+
(12.3456),
46+
(-12.3456),
47+
(1.7976931348623157E308),
48+
(-1.7976931348623157E308),
49+
(4.9E-324),
50+
(cast('NaN' as double)),
51+
(cast('Infinity' as double)),
52+
(cast('-Infinity' as double)),
53+
(NULL)
54+
55+
query
56+
SELECT rint(v) FROM test_rint
57+
58+
-- column with arithmetic
59+
query
60+
SELECT rint(v + 0.5) FROM test_rint

0 commit comments

Comments
 (0)