Skip to content

Commit 203c319

Browse files
fix: round for float/double
1 parent 2dd4d5a commit 203c319

6 files changed

Lines changed: 235 additions & 455 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import org.apache.arrow.vector.{Float8Vector, IntVector, ValueVector}
23+
24+
import org.apache.comet.CometArrowAllocator
25+
26+
/**
27+
* `round(double, scale)` implemented by delegating to Scala's `BigDecimal(d)`, which goes through
28+
* `java.lang.Double.toString` before applying the requested scale. This matches Spark's
29+
* `RoundBase` for `DoubleType` exactly on whatever JDK the executor is running, so output stays
30+
* consistent across Java 17 / 21 even though the underlying `Double.toString` algorithm differs.
31+
*
32+
* Inputs:
33+
* - inputs(0): Float8Vector value column (length = numRows, or length 1 when literal-folded)
34+
* - inputs(1): IntVector scale, length-1 scalar (serde guarantees this)
35+
*
36+
* Output: Float8Vector, length numRows.
37+
*/
38+
class RoundDoubleUDF extends CometUDF {

  /**
   * Rounds the value column to the scale carried by the second input.
   *
   * @param inputs
   *   exactly two vectors: a `Float8Vector` of values followed by an `IntVector` whose first
   *   element is the (non-null) scale
   * @param numRows
   *   number of rows to produce
   * @return
   *   a newly allocated `Float8Vector` holding `numRows` entries; null inputs produce null outputs
   * @throws java.lang.IllegalArgumentException
   *   if the number of inputs is not 2, or the scale vector is empty or null at index 0
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"RoundDoubleUDF expects 2 inputs, got ${inputs.length}")
    val values = inputs(0).asInstanceOf[Float8Vector]
    val scaleVec = inputs(1).asInstanceOf[IntVector]
    require(
      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
      "RoundDoubleUDF requires a non-null scalar scale")
    val scale = scaleVec.get(0)

    val out = new Float8Vector("round_double", CometArrowAllocator)
    // Tracks whether `out` is handed back to the caller. If anything below throws (allocation
    // failure, an out-of-range read, or BigDecimal rejecting an extreme scale) the vector is
    // closed here so its Arrow buffers are not leaked.
    var completed = false
    try {
      out.allocateNew(numRows)

      // A single-element value vector with numRows != 1 is treated as a scalar to broadcast.
      val valueIsScalar = values.getValueCount == 1 && numRows != 1
      if (valueIsScalar) {
        if (values.isNull(0)) {
          var i = 0
          while (i < numRows) { out.setNull(i); i += 1 }
        } else {
          // Round once and replicate the result into every output row.
          val rounded = RoundDoubleUDF.roundDouble(values.get(0), scale)
          var i = 0
          while (i < numRows) { out.set(i, rounded); i += 1 }
        }
      } else {
        var i = 0
        while (i < numRows) {
          if (values.isNull(i)) {
            out.setNull(i)
          } else {
            out.set(i, RoundDoubleUDF.roundDouble(values.get(i), scale))
          }
          i += 1
        }
      }
      out.setValueCount(numRows)
      completed = true
      out
    } finally {
      if (!completed) out.close()
    }
  }
}
77+
78+
object RoundDoubleUDF {

  /**
   * Rounds `v` to `scale` decimal places using HALF_UP.
   *
   * NaN and the infinities are returned unchanged; every other value is converted with Scala's
   * `BigDecimal(v)`, rescaled, and converted back to a `Double`.
   *
   * @param v
   *   value to round
   * @param scale
   *   number of fractional digits to keep; negative values round to the left of the decimal point
   * @return
   *   the rounded value
   */
  def roundDouble(v: Double, scale: Int): Double = {
    val finite = !(v.isNaN || v.isInfinite)
    if (finite) {
      val exact = BigDecimal(v)
      val rescaled = exact.setScale(scale, BigDecimal.RoundingMode.HALF_UP)
      rescaled.doubleValue
    } else {
      v
    }
  }
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.udf
21+
22+
import org.apache.arrow.vector.{Float4Vector, IntVector, ValueVector}
23+
24+
import org.apache.comet.CometArrowAllocator
25+
26+
/**
27+
* `round(float, scale)` implemented to mirror Spark's `RoundBase` for `FloatType`: widen to
28+
* double, build a `BigDecimal` via `java.lang.Double.toString`, apply HALF_UP at the requested
29+
* scale, then narrow back to float. The widening before BigDecimal construction is intentional:
30+
* it matches Spark and produces the same result string the JDK uses for the value.
31+
*
32+
* Inputs:
33+
* - inputs(0): Float4Vector value column (length = numRows, or length 1 when literal-folded)
34+
* - inputs(1): IntVector scale, length-1 scalar (serde guarantees this)
35+
*
36+
* Output: Float4Vector, length numRows.
37+
*/
38+
class RoundFloatUDF extends CometUDF {

  /**
   * Rounds the value column to the scale carried by the second input.
   *
   * @param inputs
   *   exactly two vectors: a `Float4Vector` of values followed by an `IntVector` whose first
   *   element is the (non-null) scale
   * @param numRows
   *   number of rows to produce
   * @return
   *   a newly allocated `Float4Vector` holding `numRows` entries; null inputs produce null outputs
   * @throws java.lang.IllegalArgumentException
   *   if the number of inputs is not 2, or the scale vector is empty or null at index 0
   */
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    require(inputs.length == 2, s"RoundFloatUDF expects 2 inputs, got ${inputs.length}")
    val values = inputs(0).asInstanceOf[Float4Vector]
    val scaleVec = inputs(1).asInstanceOf[IntVector]
    require(
      scaleVec.getValueCount >= 1 && !scaleVec.isNull(0),
      "RoundFloatUDF requires a non-null scalar scale")
    val scale = scaleVec.get(0)

    val out = new Float4Vector("round_float", CometArrowAllocator)
    // Tracks whether `out` is handed back to the caller. If anything below throws (allocation
    // failure, an out-of-range read, or BigDecimal rejecting an extreme scale) the vector is
    // closed here so its Arrow buffers are not leaked.
    var completed = false
    try {
      out.allocateNew(numRows)

      // A single-element value vector with numRows != 1 is treated as a scalar to broadcast.
      val valueIsScalar = values.getValueCount == 1 && numRows != 1
      if (valueIsScalar) {
        if (values.isNull(0)) {
          var i = 0
          while (i < numRows) { out.setNull(i); i += 1 }
        } else {
          // Round once and replicate the result into every output row.
          val rounded = RoundFloatUDF.roundFloat(values.get(0), scale)
          var i = 0
          while (i < numRows) { out.set(i, rounded); i += 1 }
        }
      } else {
        var i = 0
        while (i < numRows) {
          if (values.isNull(i)) {
            out.setNull(i)
          } else {
            out.set(i, RoundFloatUDF.roundFloat(values.get(i), scale))
          }
          i += 1
        }
      }
      out.setValueCount(numRows)
      completed = true
      out
    } finally {
      if (!completed) out.close()
    }
  }
}
77+
78+
object RoundFloatUDF {

  /**
   * Rounds `v` to `scale` decimal places using HALF_UP.
   *
   * NaN and the infinities are returned unchanged. Any other value is first widened to `Double`,
   * converted with Scala's `BigDecimal(...)`, rescaled, and narrowed back to `Float`.
   *
   * @param v
   *   value to round
   * @param scale
   *   number of fractional digits to keep; negative values round to the left of the decimal point
   * @return
   *   the rounded value
   */
  def roundFloat(v: Float, scale: Int): Float = {
    val finite = !(v.isNaN || v.isInfinite)
    if (finite) {
      // Rounding operates on the widened double, not on the float's short decimal form.
      val widened = BigDecimal(v.toDouble)
      val rescaled = widened.setScale(scale, BigDecimal.RoundingMode.HALF_UP)
      rescaled.floatValue
    } else {
      v
    }
  }
}

native/Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/spark-expr/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ twox-hash = "2.1.2"
4343
rand = { workspace = true }
4444
hex = "0.4.3"
4545
base64 = "0.22.1"
46-
bigdecimal = "0.4"
4746

4847
[dev-dependencies]
4948
arrow = {workspace = true}

0 commit comments

Comments
 (0)