Skip to content

Commit 3470e5f

Browse files
fix: round for float/double
1 parent 203c319 commit 3470e5f

1 file changed

Lines changed: 49 additions & 61 deletions

File tree

  • native/spark-expr/src/math_funcs

native/spark-expr/src/math_funcs/round.rs

Lines changed: 49 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@ use crate::arithmetic_overflow_error;
1919
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
2020
use arrow::array::{Array, ArrowNativeTypeOp};
2121
use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
22-
use arrow::datatypes::{DataType, Field};
22+
use arrow::datatypes::DataType;
2323
use arrow::error::ArrowError;
24-
use datafusion::common::config::ConfigOptions;
2524
use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
26-
use datafusion::functions::math::round::RoundFunc;
27-
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
2825
use datafusion::physical_plan::ColumnarValue;
2926
use std::{cmp::min, sync::Arc};
3027

@@ -110,8 +107,6 @@ pub fn spark_round(
110107
let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
111108
return internal_err!("Invalid point argument for Round(): {:#?}", point);
112109
};
113-
// DataFusion's RoundFunc expects Int32 for decimal_places
114-
let point_i32 = ColumnarValue::Scalar(ScalarValue::Int32(Some(*point as i32)));
115110
match value {
116111
ColumnarValue::Array(array) => match array.data_type() {
117112
DataType::Int64 if *point < 0 => {
@@ -131,18 +126,9 @@ pub fn spark_round(
131126
let (precision, scale) = get_precision_scale(data_type);
132127
make_decimal_array(array, precision, scale, &f)
133128
}
134-
DataType::Float32 | DataType::Float64 => {
135-
let round_udf = RoundFunc::new();
136-
let return_field = Arc::new(Field::new("round", array.data_type().clone(), true));
137-
let args_for_round = ScalarFunctionArgs {
138-
args: vec![ColumnarValue::Array(Arc::clone(array)), point_i32.clone()],
139-
number_rows: array.len(),
140-
return_field,
141-
arg_fields: vec![],
142-
config_options: Arc::new(ConfigOptions::default()),
143-
};
144-
round_udf.invoke_with_args(args_for_round)
145-
}
129+
// Float32 / Float64 are routed to a JVM UDF (RoundFloatUDF / RoundDoubleUDF) by the
130+
// serde, because matching Spark's BigDecimal-via-Double.toString rounding from native
131+
// code does not stay consistent across JDK versions.
146132
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
147133
},
148134
ColumnarValue::Scalar(a) => match a {
@@ -163,19 +149,6 @@ pub fn spark_round(
163149
let (precision, scale) = get_precision_scale(data_type);
164150
make_decimal_scalar(a, precision, scale, &f)
165151
}
166-
ScalarValue::Float32(_) | ScalarValue::Float64(_) => {
167-
let round_udf = RoundFunc::new();
168-
let data_type = a.data_type();
169-
let return_field = Arc::new(Field::new("round", data_type, true));
170-
let args_for_round = ScalarFunctionArgs {
171-
args: vec![ColumnarValue::Scalar(a.clone()), point_i32.clone()],
172-
number_rows: 1,
173-
return_field,
174-
arg_fields: vec![],
175-
config_options: Arc::new(ConfigOptions::default()),
176-
};
177-
round_udf.invoke_with_args(args_for_round)
178-
}
179152
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
180153
},
181154
}
@@ -207,77 +180,92 @@ mod test {
207180

208181
use crate::spark_round;
209182

210-
use arrow::array::{Float32Array, Float64Array};
183+
use arrow::array::Decimal128Array;
211184
use arrow::datatypes::DataType;
212-
use datafusion::common::cast::{as_float32_array, as_float64_array};
213185
use datafusion::common::{Result, ScalarValue};
214186
use datafusion::physical_plan::ColumnarValue;
215187

216188
#[test]
217189
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
218-
fn test_round_f32_array() -> Result<()> {
190+
fn test_round_decimal128_array_pos_point() -> Result<()> {
191+
// Decimal128(10, 4) values: 125.2345, 15.3455, 0.1234, 0.1250, 0.7850, 123.1230
192+
let input = Decimal128Array::from(vec![1252345, 153455, 1234, 1250, 7850, 1231230])
193+
.with_precision_and_scale(10, 4)?;
219194
let args = vec![
220-
ColumnarValue::Array(Arc::new(Float32Array::from(vec![
221-
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
222-
]))),
195+
ColumnarValue::Array(Arc::new(input)),
223196
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
224197
];
225-
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32, false)? else {
198+
let return_type = DataType::Decimal128(8, 2);
199+
let ColumnarValue::Array(result) = spark_round(&args, &return_type, false)? else {
226200
unreachable!()
227201
};
228-
let floats = as_float32_array(&result)?;
229-
let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
230-
assert_eq!(floats, &expected);
202+
// HALF_UP: 0.125 -> 0.13, 0.785 -> 0.79
203+
let expected = Decimal128Array::from(vec![12523, 1535, 12, 13, 79, 12312])
204+
.with_precision_and_scale(8, 2)?;
205+
let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
206+
assert_eq!(actual, &expected);
231207
Ok(())
232208
}
233209

234210
#[test]
235211
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
236-
fn test_round_f64_array() -> Result<()> {
212+
fn test_round_decimal128_array_neg_point() -> Result<()> {
213+
// Decimal128(10, 4) values: 125.2345, -125.2345, 150.0000, -150.0000, 0.0000
214+
let input = Decimal128Array::from(vec![1252345, -1252345, 1500000, -1500000, 0])
215+
.with_precision_and_scale(10, 4)?;
237216
let args = vec![
238-
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
239-
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
240-
]))),
241-
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
217+
ColumnarValue::Array(Arc::new(input)),
218+
ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
242219
];
243-
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64, false)? else {
220+
let return_type = DataType::Decimal128(6, 0);
221+
let ColumnarValue::Array(result) = spark_round(&args, &return_type, false)? else {
244222
unreachable!()
245223
};
246-
let floats = as_float64_array(&result)?;
247-
let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
248-
assert_eq!(floats, &expected);
224+
// HALF_UP: 125.2345 rounds DOWN to 100, 150 ties round AWAY from zero to 200
225+
let expected = Decimal128Array::from(vec![100, -100, 200, -200, 0])
226+
.with_precision_and_scale(6, 0)?;
227+
let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
228+
assert_eq!(actual, &expected);
249229
Ok(())
250230
}
251231

252232
#[test]
253233
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
254-
fn test_round_f32_scalar() -> Result<()> {
234+
fn test_round_decimal128_scalar_pos_point() -> Result<()> {
235+
// 125.2345, point=2 -> 125.23
255236
let args = vec![
256-
ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
237+
ColumnarValue::Scalar(ScalarValue::Decimal128(Some(1252345), 10, 4)),
257238
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
258239
];
259-
let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
260-
spark_round(&args, &DataType::Float32, false)?
240+
let return_type = DataType::Decimal128(8, 2);
241+
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), p, s)) =
242+
spark_round(&args, &return_type, false)?
261243
else {
262244
unreachable!()
263245
};
264-
assert_eq!(result, 125.23);
246+
assert_eq!(result, 12523);
247+
assert_eq!(p, 8);
248+
assert_eq!(s, 2);
265249
Ok(())
266250
}
267251

268252
#[test]
269253
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
270-
fn test_round_f64_scalar() -> Result<()> {
254+
fn test_round_decimal128_scalar_neg_point() -> Result<()> {
255+
// 150.0000, point=-2 -> 200 (HALF_UP rounds the .5 tie away from zero)
271256
let args = vec![
272-
ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
273-
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
257+
ColumnarValue::Scalar(ScalarValue::Decimal128(Some(1500000), 10, 4)),
258+
ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
274259
];
275-
let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
276-
spark_round(&args, &DataType::Float64, false)?
260+
let return_type = DataType::Decimal128(6, 0);
261+
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), p, s)) =
262+
spark_round(&args, &return_type, false)?
277263
else {
278264
unreachable!()
279265
};
280-
assert_eq!(result, 125.23);
266+
assert_eq!(result, 200);
267+
assert_eq!(p, 6);
268+
assert_eq!(s, 0);
281269
Ok(())
282270
}
283271
}

0 commit comments

Comments
 (0)