Skip to content

Commit d51c9d6

Browse files
committed
Merge remote-tracking branch 'apache/main' into feat/bloom-filter-intermediate-buffer-compat
2 parents 16299be + 184a883 commit d51c9d6

7 files changed

Lines changed: 86 additions & 7 deletions

File tree

docs/source/contributor-guide/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@
411411
- [x] randn
412412
- [ ] random
413413
- [ ] randstr
414-
- [ ] rint
414+
- [x] rint
415415
- [x] round
416416
- [x] sec
417417
- [x] shiftleft

docs/source/user-guide/latest/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ of expressions that can be disabled.
174174
| Rand | `rand` |
175175
| Randn | `randn` |
176176
| Remainder | `%` |
177+
| Rint | `rint` |
177178
| Round | `round` |
178179
| Sec | `sec` |
179180
| Signum | `signum` |

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use datafusion_spark::function::map::str_to_map::SparkStrToMap;
6060
use datafusion_spark::function::math::expm1::SparkExpm1;
6161
use datafusion_spark::function::math::factorial::SparkFactorial;
6262
use datafusion_spark::function::math::hex::SparkHex;
63+
use datafusion_spark::function::math::rint::SparkRint;
6364
use datafusion_spark::function::math::trigonometry::SparkCsc;
6465
use datafusion_spark::function::math::trigonometry::SparkSec;
6566
use datafusion_spark::function::math::width_bucket::SparkWidthBucket;
@@ -605,6 +606,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
605606
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkTryParseUrl::default()));
606607
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkFactorial::default()));
607608
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSec::default()));
609+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkRint::default()));
608610
}
609611

610612
/// Prepares arrow arrays for output.

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
125125
classOf[Rand] -> CometRand,
126126
classOf[Randn] -> CometRandn,
127127
classOf[Remainder] -> CometRemainder,
128+
classOf[Rint] -> CometScalarFunction("rint"),
128129
classOf[Round] -> CometRound,
129130
classOf[Sec] -> CometScalarFunction("sec"),
130131
classOf[Signum] -> CometScalarFunction("signum"),
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing,
12+
-- software distributed under the License is distributed on an
13+
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
-- KIND, either express or implied. See the License for the
15+
-- specific language governing permissions and limitations
16+
-- under the License.
17+
18+
-- ConfigMatrix: parquet.enable.dictionary=false,true
19+
20+
-- Spark's Rint extends UnaryMathExpression with inputTypes = Seq(DoubleType).
21+
-- It returns the double value that is closest to the argument and equal to a mathematical integer
22+
-- (Java's Math.rint, IEEE 754 round-half-to-even / banker's rounding).
23+
24+
statement
25+
CREATE TABLE test_rint(v double) USING parquet
26+
27+
statement
28+
INSERT INTO test_rint VALUES
29+
(0.0),
30+
(-0.0),
31+
(1.0),
32+
(-1.0),
33+
(0.4),
34+
(0.5),
35+
(0.6),
36+
(1.5),
37+
(2.5),
38+
(3.5),
39+
(-0.4),
40+
(-0.5),
41+
(-0.6),
42+
(-1.5),
43+
(-2.5),
44+
(-3.5),
45+
(12.3456),
46+
(-12.3456),
47+
(1.7976931348623157E308),
48+
(-1.7976931348623157E308),
49+
(4.9E-324),
50+
(cast('NaN' as double)),
51+
(cast('Infinity' as double)),
52+
(cast('-Infinity' as double)),
53+
(NULL)
54+
55+
query
56+
SELECT rint(v) FROM test_rint
57+
58+
-- column with arithmetic
59+
query
60+
SELECT rint(v + 0.5) FROM test_rint

spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,16 @@ class CometCodegenFuzzSuite
308308
}
309309
}
310310

311-
private def probeCardinality(accessor: String, viewName: String): Unit = {
311+
private def probeCardinality(accessor: String, dt: DataType, viewName: String): Unit = {
312+
// `Size` only supports `ArrayType` in Comet, so for `MapType` we route through `map_keys` to
313+
// reach a `Size(ArrayType)`. Spark still calls `getMap` on the column vector to extract the
314+
// keys, which is the accessor path this probe is intended to exercise.
315+
val sizeExpr = dt match {
316+
case _: MapType => s"size(map_keys($accessor))"
317+
case _ => s"cardinality($accessor)"
318+
}
312319
assertCodegenRan {
313-
checkSparkAnswerAndOperator(
314-
s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName")
320+
checkSparkAnswerAndOperator(s"SELECT $cardinalityProbeUdf($sizeExpr) FROM $viewName")
315321
}
316322
}
317323

@@ -323,13 +329,14 @@ class CometCodegenFuzzSuite
323329
private def probeComplexColumn(field: StructField, viewName: String): Unit = {
324330
field.dataType match {
325331
case _: ArrayType | _: MapType =>
326-
probeCardinality(field.name, viewName)
332+
probeCardinality(field.name, field.dataType, viewName)
327333

328334
case st: StructType =>
329335
for (subField <- st.fields) {
330336
val accessor = s"${field.name}.${subField.name}"
331337
subField.dataType match {
332-
case _: ArrayType | _: MapType => probeCardinality(accessor, viewName)
338+
case _: ArrayType | _: MapType =>
339+
probeCardinality(accessor, subField.dataType, viewName)
333340
case dt if !isComplexType(dt) =>
334341
val udfName = s"id_${field.name}_${subField.name}"
335342
registerIdentityUdfFor(dt, udfName).foreach { _ =>

spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,15 @@ class CometCodegenSuite
646646
}
647647

648648
test("ScalaUDF over Decimal(38, 10) routes through the BigDecimal slow path") {
649-
spark.udf.register("decIdLong", (d: java.math.BigDecimal) => d)
649+
// Pin the return type to Decimal(38, 10). TypeTag inference for `BigDecimal` would default to
650+
// Decimal(38, 18), and under Spark 4 ANSI the encoder's CheckOverflow throws on the 28-digit
651+
// boundary value below when rescaling 10 -> 18.
652+
spark.udf.register(
653+
"decIdLong",
654+
new UDF1[java.math.BigDecimal, java.math.BigDecimal] {
655+
override def call(d: java.math.BigDecimal): java.math.BigDecimal = d
656+
},
657+
DecimalType(38, 10))
650658
withDecimalTable(
651659
"DECIMAL(38, 10)",
652660
Seq(

0 commit comments

Comments
 (0)