@@ -19,12 +19,9 @@ use crate::arithmetic_overflow_error;
1919use crate :: math_funcs:: utils:: { get_precision_scale, make_decimal_array, make_decimal_scalar} ;
2020use arrow:: array:: { Array , ArrowNativeTypeOp } ;
2121use arrow:: array:: { Int16Array , Int32Array , Int64Array , Int8Array } ;
22- use arrow:: datatypes:: { DataType , Field } ;
22+ use arrow:: datatypes:: DataType ;
2323use arrow:: error:: ArrowError ;
24- use datafusion:: common:: config:: ConfigOptions ;
2524use datafusion:: common:: { exec_err, internal_err, DataFusionError , ScalarValue } ;
26- use datafusion:: functions:: math:: round:: RoundFunc ;
27- use datafusion:: logical_expr:: { ScalarFunctionArgs , ScalarUDFImpl } ;
2825use datafusion:: physical_plan:: ColumnarValue ;
2926use std:: { cmp:: min, sync:: Arc } ;
3027
@@ -110,8 +107,6 @@ pub fn spark_round(
110107 let ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( point) ) ) = point else {
111108 return internal_err ! ( "Invalid point argument for Round(): {:#?}" , point) ;
112109 } ;
113- // DataFusion's RoundFunc expects Int32 for decimal_places
114- let point_i32 = ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( * point as i32 ) ) ) ;
115110 match value {
116111 ColumnarValue :: Array ( array) => match array. data_type ( ) {
117112 DataType :: Int64 if * point < 0 => {
@@ -131,18 +126,9 @@ pub fn spark_round(
131126 let ( precision, scale) = get_precision_scale ( data_type) ;
132127 make_decimal_array ( array, precision, scale, & f)
133128 }
134- DataType :: Float32 | DataType :: Float64 => {
135- let round_udf = RoundFunc :: new ( ) ;
136- let return_field = Arc :: new ( Field :: new ( "round" , array. data_type ( ) . clone ( ) , true ) ) ;
137- let args_for_round = ScalarFunctionArgs {
138- args : vec ! [ ColumnarValue :: Array ( Arc :: clone( array) ) , point_i32. clone( ) ] ,
139- number_rows : array. len ( ) ,
140- return_field,
141- arg_fields : vec ! [ ] ,
142- config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
143- } ;
144- round_udf. invoke_with_args ( args_for_round)
145- }
129+ // Float32 / Float64 are routed to a JVM UDF (RoundFloatUDF / RoundDoubleUDF) by the
130+ // serde, because matching Spark's BigDecimal-via-Double.toString rounding from native
131+ // code does not stay consistent across JDK versions.
146132 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
147133 } ,
148134 ColumnarValue :: Scalar ( a) => match a {
@@ -163,19 +149,6 @@ pub fn spark_round(
163149 let ( precision, scale) = get_precision_scale ( data_type) ;
164150 make_decimal_scalar ( a, precision, scale, & f)
165151 }
166- ScalarValue :: Float32 ( _) | ScalarValue :: Float64 ( _) => {
167- let round_udf = RoundFunc :: new ( ) ;
168- let data_type = a. data_type ( ) ;
169- let return_field = Arc :: new ( Field :: new ( "round" , data_type, true ) ) ;
170- let args_for_round = ScalarFunctionArgs {
171- args : vec ! [ ColumnarValue :: Scalar ( a. clone( ) ) , point_i32. clone( ) ] ,
172- number_rows : 1 ,
173- return_field,
174- arg_fields : vec ! [ ] ,
175- config_options : Arc :: new ( ConfigOptions :: default ( ) ) ,
176- } ;
177- round_udf. invoke_with_args ( args_for_round)
178- }
179152 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
180153 } ,
181154 }
@@ -207,77 +180,92 @@ mod test {
207180
208181 use crate :: spark_round;
209182
210- use arrow:: array:: { Float32Array , Float64Array } ;
183+ use arrow:: array:: Decimal128Array ;
211184 use arrow:: datatypes:: DataType ;
212- use datafusion:: common:: cast:: { as_float32_array, as_float64_array} ;
213185 use datafusion:: common:: { Result , ScalarValue } ;
214186 use datafusion:: physical_plan:: ColumnarValue ;
215187
216188 #[ test]
217189 #[ cfg_attr( miri, ignore) ] // rounding does not work when miri enabled
218- fn test_round_f32_array ( ) -> Result < ( ) > {
190+ fn test_round_decimal128_array_pos_point ( ) -> Result < ( ) > {
191+ // Decimal128(10, 4) values: 125.2345, 15.3455, 0.1234, 0.1250, 0.7850, 123.1230
192+ let input = Decimal128Array :: from ( vec ! [ 1252345 , 153455 , 1234 , 1250 , 7850 , 1231230 ] )
193+ . with_precision_and_scale ( 10 , 4 ) ?;
219194 let args = vec ! [
220- ColumnarValue :: Array ( Arc :: new( Float32Array :: from( vec![
221- 125.2345 , 15.3455 , 0.1234 , 0.125 , 0.785 , 123.123 ,
222- ] ) ) ) ,
195+ ColumnarValue :: Array ( Arc :: new( input) ) ,
223196 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
224197 ] ;
225- let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float32 , false ) ? else {
198+ let return_type = DataType :: Decimal128 ( 8 , 2 ) ;
199+ let ColumnarValue :: Array ( result) = spark_round ( & args, & return_type, false ) ? else {
226200 unreachable ! ( )
227201 } ;
228- let floats = as_float32_array ( & result) ?;
229- let expected = Float32Array :: from ( vec ! [ 125.23 , 15.35 , 0.12 , 0.13 , 0.79 , 123.12 ] ) ;
230- assert_eq ! ( floats, & expected) ;
202+ // HALF_UP: 0.125 -> 0.13, 0.785 -> 0.79
203+ let expected = Decimal128Array :: from ( vec ! [ 12523 , 1535 , 12 , 13 , 79 , 12312 ] )
204+ . with_precision_and_scale ( 8 , 2 ) ?;
205+ let actual = result. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
206+ assert_eq ! ( actual, & expected) ;
231207 Ok ( ( ) )
232208 }
233209
234210 #[ test]
235211 #[ cfg_attr( miri, ignore) ] // rounding does not work when miri enabled
236- fn test_round_f64_array ( ) -> Result < ( ) > {
212+ fn test_round_decimal128_array_neg_point ( ) -> Result < ( ) > {
213+ // Decimal128(10, 4) values: 125.2345, -125.2345, 150.0000, -150.0000, 0.0000
214+ let input = Decimal128Array :: from ( vec ! [ 1252345 , -1252345 , 1500000 , -1500000 , 0 ] )
215+ . with_precision_and_scale ( 10 , 4 ) ?;
237216 let args = vec ! [
238- ColumnarValue :: Array ( Arc :: new( Float64Array :: from( vec![
239- 125.2345 , 15.3455 , 0.1234 , 0.125 , 0.785 , 123.123 ,
240- ] ) ) ) ,
241- ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
217+ ColumnarValue :: Array ( Arc :: new( input) ) ,
218+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( -2 ) ) ) ,
242219 ] ;
243- let ColumnarValue :: Array ( result) = spark_round ( & args, & DataType :: Float64 , false ) ? else {
220+ let return_type = DataType :: Decimal128 ( 6 , 0 ) ;
221+ let ColumnarValue :: Array ( result) = spark_round ( & args, & return_type, false ) ? else {
244222 unreachable ! ( )
245223 } ;
246- let floats = as_float64_array ( & result) ?;
247- let expected = Float64Array :: from ( vec ! [ 125.23 , 15.35 , 0.12 , 0.13 , 0.79 , 123.12 ] ) ;
248- assert_eq ! ( floats, & expected) ;
224+ // HALF_UP: 125.2345 rounds DOWN to 100, 150 ties round AWAY from zero to 200
225+ let expected = Decimal128Array :: from ( vec ! [ 100 , -100 , 200 , -200 , 0 ] )
226+ . with_precision_and_scale ( 6 , 0 ) ?;
227+ let actual = result. as_any ( ) . downcast_ref :: < Decimal128Array > ( ) . unwrap ( ) ;
228+ assert_eq ! ( actual, & expected) ;
249229 Ok ( ( ) )
250230 }
251231
252232 #[ test]
253233 #[ cfg_attr( miri, ignore) ] // rounding does not work when miri enabled
254- fn test_round_f32_scalar ( ) -> Result < ( ) > {
234+ fn test_round_decimal128_scalar_pos_point ( ) -> Result < ( ) > {
235+ // 125.2345, point=2 -> 125.23
255236 let args = vec ! [
256- ColumnarValue :: Scalar ( ScalarValue :: Float32 ( Some ( 125.2345 ) ) ) ,
237+ ColumnarValue :: Scalar ( ScalarValue :: Decimal128 ( Some ( 1252345 ) , 10 , 4 ) ) ,
257238 ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
258239 ] ;
259- let ColumnarValue :: Scalar ( ScalarValue :: Float32 ( Some ( result) ) ) =
260- spark_round ( & args, & DataType :: Float32 , false ) ?
240+ let return_type = DataType :: Decimal128 ( 8 , 2 ) ;
241+ let ColumnarValue :: Scalar ( ScalarValue :: Decimal128 ( Some ( result) , p, s) ) =
242+ spark_round ( & args, & return_type, false ) ?
261243 else {
262244 unreachable ! ( )
263245 } ;
264- assert_eq ! ( result, 125.23 ) ;
246+ assert_eq ! ( result, 12523 ) ;
247+ assert_eq ! ( p, 8 ) ;
248+ assert_eq ! ( s, 2 ) ;
265249 Ok ( ( ) )
266250 }
267251
268252 #[ test]
269253 #[ cfg_attr( miri, ignore) ] // rounding does not work when miri enabled
270- fn test_round_f64_scalar ( ) -> Result < ( ) > {
254+ fn test_round_decimal128_scalar_neg_point ( ) -> Result < ( ) > {
255+ // 150.0000, point=-2 -> 200 (HALF_UP rounds the .5 tie away from zero)
271256 let args = vec ! [
272- ColumnarValue :: Scalar ( ScalarValue :: Float64 ( Some ( 125.2345 ) ) ) ,
273- ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 2 ) ) ) ,
257+ ColumnarValue :: Scalar ( ScalarValue :: Decimal128 ( Some ( 1500000 ) , 10 , 4 ) ) ,
258+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( - 2 ) ) ) ,
274259 ] ;
275- let ColumnarValue :: Scalar ( ScalarValue :: Float64 ( Some ( result) ) ) =
276- spark_round ( & args, & DataType :: Float64 , false ) ?
260+ let return_type = DataType :: Decimal128 ( 6 , 0 ) ;
261+ let ColumnarValue :: Scalar ( ScalarValue :: Decimal128 ( Some ( result) , p, s) ) =
262+ spark_round ( & args, & return_type, false ) ?
277263 else {
278264 unreachable ! ( )
279265 } ;
280- assert_eq ! ( result, 125.23 ) ;
266+ assert_eq ! ( result, 200 ) ;
267+ assert_eq ! ( p, 6 ) ;
268+ assert_eq ! ( s, 0 ) ;
281269 Ok ( ( ) )
282270 }
283271}
0 commit comments