Skip to content

Commit 548fd18

Browse files
committed
spark_compatible_ceil_function
1 parent 83f7a3b commit 548fd18

File tree

1 file changed

+120
-2
lines changed
  • datafusion/spark/src/function/math

1 file changed

+120
-2
lines changed

datafusion/spark/src/function/math/ceil.rs

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use arrow::array::cast::AsArray;
19+
use arrow::array::types::Decimal128Type;
1820
use arrow::array::{
19-
Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
21+
ArrayRef, Decimal128Array, Float32Array, Float64Array, Int8Array, Int16Array,
22+
Int32Array, Int64Array,
2023
};
2124
use arrow::compute::kernels::arity::unary;
2225
use arrow::datatypes::DataType;
@@ -26,7 +29,6 @@ use datafusion_expr::{
2629
};
2730
use std::any::Any;
2831
use std::sync::Arc;
29-
// spark semantics
3032

3133
macro_rules! downcast_compute_op {
3234
($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
@@ -78,6 +80,10 @@ pub fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErr
7880
// Optimization: Int64 -> Int64 doesn't need conversion, just return same array
7981
Ok(ColumnarValue::Array(Arc::clone(array)))
8082
}
83+
DataType::Decimal128(precision, scale) if *scale > 0 => {
84+
let f = decimal_ceil_f(*scale);
85+
make_decimal_array(array, *precision, *scale, &f)
86+
}
8187
other => Err(DataFusionError::Internal(format!(
8288
"Unsupported data type {other:?} for function ceil",
8389
))),
@@ -99,6 +105,13 @@ pub fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErr
99105
a.map(|x| x as i64),
100106
))),
101107
ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(*a))),
108+
ScalarValue::Decimal128(a, precision, scale) if *scale > 0 => {
109+
let f = decimal_ceil_f(*scale);
110+
let result = a.map(f);
111+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
112+
result, *precision, *scale,
113+
)))
114+
}
102115
_ => Err(DataFusionError::Internal(format!(
103116
"Unsupported data type {:?} for function ceil",
104117
value.data_type(),
@@ -107,6 +120,33 @@ pub fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionErr
107120
}
108121
}
109122

123+
/// Computes ceil for a Decimal128 array, preserving the scale.
124+
/// Divides by 10^scale, takes ceiling, then multiplies back.
125+
#[inline]
126+
fn make_decimal_array(
127+
array: &ArrayRef,
128+
precision: u8,
129+
scale: i8,
130+
f: &dyn Fn(i128) -> i128,
131+
) -> Result<ColumnarValue, DataFusionError> {
132+
let array = array.as_primitive::<Decimal128Type>();
133+
let result: Decimal128Array = unary(array, f);
134+
let result = result.with_data_type(DataType::Decimal128(precision, scale));
135+
Ok(ColumnarValue::Array(Arc::new(result)))
136+
}
137+
138+
/// Returns a closure that computes ceil for decimal values.
139+
#[inline]
140+
fn decimal_ceil_f(scale: i8) -> impl Fn(i128) -> i128 {
141+
let div = 10_i128.pow(scale as u32);
142+
move |x: i128| {
143+
let d = x / div;
144+
let r = x % div;
145+
// Ceiling: round up for positive remainders
146+
(if r > 0 { d + 1 } else { d }) * div
147+
}
148+
}
149+
110150
#[derive(Debug, PartialEq, Eq, Hash)]
111151
pub struct SparkCiel {
112152
signature: Signature,
@@ -153,3 +193,81 @@ impl ScalarUDFImpl for SparkCiel {
153193
spark_ceil(&args.args)
154194
}
155195
}
196+
197+
#[cfg(test)]
198+
mod tests {
199+
use super::*;
200+
use arrow::array::Decimal128Array;
201+
use datafusion_common::Result;
202+
use datafusion_common::cast::as_decimal128_array;
203+
204+
#[test]
205+
fn test_ceil_decimal128_array() -> Result<()> {
206+
let array = Decimal128Array::from(vec![
207+
Some(12345), // 123.45
208+
Some(12500), // 125.00
209+
Some(-12999), // -129.99
210+
None,
211+
])
212+
.with_precision_and_scale(5, 2)?;
213+
let args = vec![ColumnarValue::Array(Arc::new(array))];
214+
let ColumnarValue::Array(result) = spark_ceil(&args)? else {
215+
unreachable!()
216+
};
217+
let expected = Decimal128Array::from(vec![
218+
Some(12400), // 124.00
219+
Some(12500), // 125.00
220+
Some(-12900), // -129.00
221+
None,
222+
])
223+
.with_precision_and_scale(5, 2)?;
224+
let actual = as_decimal128_array(&result)?;
225+
assert_eq!(actual, &expected);
226+
Ok(())
227+
}
228+
229+
#[test]
230+
fn test_ceil_decimal128_scalar() -> Result<()> {
231+
let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
232+
Some(567),
233+
3,
234+
1,
235+
))]; // 56.7
236+
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 1)) =
237+
spark_ceil(&args)?
238+
else {
239+
unreachable!()
240+
};
241+
assert_eq!(result, 570); // 57.0
242+
Ok(())
243+
}
244+
245+
#[test]
246+
fn test_ceil_decimal128_negative_scalar() -> Result<()> {
247+
// -56.7 should ceil to -56.0
248+
let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(
249+
Some(-567),
250+
3,
251+
1,
252+
))];
253+
let ColumnarValue::Scalar(ScalarValue::Decimal128(Some(result), 3, 1)) =
254+
spark_ceil(&args)?
255+
else {
256+
unreachable!()
257+
};
258+
assert_eq!(result, -560); // -56.0
259+
Ok(())
260+
}
261+
262+
#[test]
263+
fn test_ceil_decimal128_null_scalar() -> Result<()> {
264+
let args = vec![ColumnarValue::Scalar(ScalarValue::Decimal128(None, 5, 2))];
265+
let ColumnarValue::Scalar(ScalarValue::Decimal128(result, 5, 2)) =
266+
spark_ceil(&args)?
267+
else {
268+
unreachable!()
269+
};
270+
assert_eq!(result, None);
271+
Ok(())
272+
}
273+
}

0 commit comments

Comments
 (0)