Skip to content

Commit b69469b

Browse files
address review comments
1 parent 3c2af76 commit b69469b

1 file changed

Lines changed: 61 additions & 49 deletions

File tree

  • native/spark-expr/src/math_funcs

native/spark-expr/src/math_funcs/round.rs

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ use crate::arithmetic_overflow_error;
1919
use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar};
2020
use arrow::array::{Array, ArrowNativeTypeOp};
2121
use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array};
22-
use arrow::datatypes::DataType;
22+
use arrow::datatypes::{DataType, Field};
2323
use arrow::error::ArrowError;
24+
use datafusion::common::config::ConfigOptions;
2425
use datafusion::common::{exec_err, internal_err, DataFusionError, ScalarValue};
26+
use datafusion::functions::math::round::RoundFunc;
27+
use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl};
2528
use datafusion::physical_plan::ColumnarValue;
2629
use std::{cmp::min, sync::Arc};
2730

@@ -107,6 +110,8 @@ pub fn spark_round(
107110
let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else {
108111
return internal_err!("Invalid point argument for Round(): {:#?}", point);
109112
};
113+
// DataFusion's RoundFunc expects Int32 for decimal_places
114+
let point_i32 = ColumnarValue::Scalar(ScalarValue::Int32(Some(*point as i32)));
110115
match value {
111116
ColumnarValue::Array(array) => match array.data_type() {
112117
DataType::Int64 if *point < 0 => {
@@ -126,9 +131,18 @@ pub fn spark_round(
126131
let (precision, scale) = get_precision_scale(data_type);
127132
make_decimal_array(array, precision, scale, &f)
128133
}
129-
// Float32 / Float64 are routed to a JVM UDF (RoundFloatUDF / RoundDoubleUDF) by the
130-
// serde, because matching Spark's BigDecimal-via-Double.toString rounding from native
131-
// code does not stay consistent across JDK versions.
134+
DataType::Float32 | DataType::Float64 => {
135+
let round_udf = RoundFunc::new();
136+
let return_field = Arc::new(Field::new("round", array.data_type().clone(), true));
137+
let args_for_round = ScalarFunctionArgs {
138+
args: vec![ColumnarValue::Array(Arc::clone(array)), point_i32.clone()],
139+
number_rows: array.len(),
140+
return_field,
141+
arg_fields: vec![],
142+
config_options: Arc::new(ConfigOptions::default()),
143+
};
144+
round_udf.invoke_with_args(args_for_round)
145+
}
132146
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
133147
},
134148
ColumnarValue::Scalar(a) => match a {
@@ -149,6 +163,19 @@ pub fn spark_round(
149163
let (precision, scale) = get_precision_scale(data_type);
150164
make_decimal_scalar(a, precision, scale, &f)
151165
}
166+
ScalarValue::Float32(_) | ScalarValue::Float64(_) => {
167+
let round_udf = RoundFunc::new();
168+
let data_type = a.data_type();
169+
let return_field = Arc::new(Field::new("round", data_type, true));
170+
let args_for_round = ScalarFunctionArgs {
171+
args: vec![ColumnarValue::Scalar(a.clone()), point_i32.clone()],
172+
number_rows: 1,
173+
return_field,
174+
arg_fields: vec![],
175+
config_options: Arc::new(ConfigOptions::default()),
176+
};
177+
round_udf.invoke_with_args(args_for_round)
178+
}
152179
dt => exec_err!("Not supported datatype for ROUND: {dt}"),
153180
},
154181
}
@@ -180,92 +207,77 @@ mod test {
180207

181208
use crate::spark_round;
182209

183-
use arrow::array::Decimal128Array;
210+
use arrow::array::{Float32Array, Float64Array};
184211
use arrow::datatypes::DataType;
212+
use datafusion::common::cast::{as_float32_array, as_float64_array};
185213
use datafusion::common::{Result, ScalarValue};
186214
use datafusion::physical_plan::ColumnarValue;
187215

188216
#[test]
189217
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
190-
fn test_round_decimal128_array_pos_point() -> Result<()> {
191-
// Decimal128(10, 4) values: 125.2345, 15.3455, 0.1234, 0.1250, 0.7850, 123.1230
192-
let input = Decimal128Array::from(vec![1252345, 153455, 1234, 1250, 7850, 1231230])
193-
.with_precision_and_scale(10, 4)?;
218+
fn test_round_f32_array() -> Result<()> {
194219
let args = vec![
195-
ColumnarValue::Array(Arc::new(input)),
220+
ColumnarValue::Array(Arc::new(Float32Array::from(vec![
221+
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
222+
]))),
196223
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
197224
];
198-
let return_type = DataType::Decimal128(8, 2);
199-
let ColumnarValue::Array(result) = spark_round(&args, &return_type, false)? else {
225+
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float32, false)? else {
200226
unreachable!()
201227
};
202-
// HALF_UP: 0.125 -> 0.13, 0.785 -> 0.79
203-
let expected = Decimal128Array::from(vec![12523, 1535, 12, 13, 79, 12312])
204-
.with_precision_and_scale(8, 2)?;
205-
let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
206-
assert_eq!(actual, &expected);
228+
let floats = as_float32_array(&result)?;
229+
let expected = Float32Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
230+
assert_eq!(floats, &expected);
207231
Ok(())
208232
}
209233

210234
#[test]
211235
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
212-
fn test_round_decimal128_array_neg_point() -> Result<()> {
213-
// Decimal128(10, 4) values: 125.2345, -125.2345, 150.0000, -150.0000, 0.0000
214-
let input = Decimal128Array::from(vec![1252345, -1252345, 1500000, -1500000, 0])
215-
.with_precision_and_scale(10, 4)?;
236+
fn test_round_f64_array() -> Result<()> {
216237
let args = vec![
217-
ColumnarValue::Array(Arc::new(input)),
218-
ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
238+
ColumnarValue::Array(Arc::new(Float64Array::from(vec![
239+
125.2345, 15.3455, 0.1234, 0.125, 0.785, 123.123,
240+
]))),
241+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
219242
];
220-
let return_type = DataType::Decimal128(6, 0);
221-
let ColumnarValue::Array(result) = spark_round(&args, &return_type, false)? else {
243+
let ColumnarValue::Array(result) = spark_round(&args, &DataType::Float64, false)? else {
222244
unreachable!()
223245
};
224-
// HALF_UP: 125.2345 rounds DOWN to 100, 150 ties round AWAY from zero to 200
225-
let expected =
226-
Decimal128Array::from(vec![100, -100, 200, -200, 0]).with_precision_and_scale(6, 0)?;
227-
let actual = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
228-
assert_eq!(actual, &expected);
246+
let floats = as_float64_array(&result)?;
247+
let expected = Float64Array::from(vec![125.23, 15.35, 0.12, 0.13, 0.79, 123.12]);
248+
assert_eq!(floats, &expected);
229249
Ok(())
230250
}
231251

232252
#[test]
233253
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
234-
fn test_round_decimal128_scalar_pos_point() -> Result<()> {
235-
// 125.2345, point=2 -> 125.23
254+
fn test_round_f32_scalar() -> Result<()> {
236255
let args = vec![
237-
ColumnarValue::Scalar(ScalarValue::Decimal128(Some(1252345), 10, 4)),
256+
ColumnarValue::Scalar(ScalarValue::Float32(Some(125.2345))),
238257
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
239258
];
240-
let return_type = DataType::Decimal128(8, 2);
241-
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), p, s)) =
242-
spark_round(&args, &return_type, false)?
259+
let ColumnarValue::Scalar(ScalarValue::Float32(Some(result))) =
260+
spark_round(&args, &DataType::Float32, false)?
243261
else {
244262
unreachable!()
245263
};
246-
assert_eq!(result, 12523);
247-
assert_eq!(p, 8);
248-
assert_eq!(s, 2);
264+
assert_eq!(result, 125.23);
249265
Ok(())
250266
}
251267

252268
#[test]
253269
#[cfg_attr(miri, ignore)] // rounding does not work when miri enabled
254-
fn test_round_decimal128_scalar_neg_point() -> Result<()> {
255-
// 150.0000, point=-2 -> 200 (HALF_UP rounds the .5 tie away from zero)
270+
fn test_round_f64_scalar() -> Result<()> {
256271
let args = vec![
257-
ColumnarValue::Scalar(ScalarValue::Decimal128(Some(1500000), 10, 4)),
258-
ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))),
272+
ColumnarValue::Scalar(ScalarValue::Float64(Some(125.2345))),
273+
ColumnarValue::Scalar(ScalarValue::Int64(Some(2))),
259274
];
260-
let return_type = DataType::Decimal128(6, 0);
261-
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), p, s)) =
262-
spark_round(&args, &return_type, false)?
275+
let ColumnarValue::Scalar(ScalarValue::Float64(Some(result))) =
276+
spark_round(&args, &DataType::Float64, false)?
263277
else {
264278
unreachable!()
265279
};
266-
assert_eq!(result, 200);
267-
assert_eq!(p, 6);
268-
assert_eq!(s, 0);
280+
assert_eq!(result, 125.23);
269281
Ok(())
270282
}
271283
}

0 commit comments

Comments
 (0)