Skip to content

Commit 83f7a3b

Browse files
committed
port_ceil_comet_to_df
1 parent 9660c98 commit 83f7a3b

File tree

3 files changed

+367
-2
lines changed

3 files changed

+367
-2
lines changed
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{
19+
Float32Array, Float64Array, Int8Array, Int16Array, Int32Array, Int64Array,
20+
};
21+
use arrow::compute::kernels::arity::unary;
22+
use arrow::datatypes::DataType;
23+
use datafusion_common::{DataFusionError, ScalarValue};
24+
use datafusion_expr::{
25+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26+
};
27+
use std::any::Any;
28+
use std::sync::Arc;
29+
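// Spark-compatible `ceil`: every supported numeric input (Float32/Float64 and
// Int8/Int16/Int32/Int64) produces an Int64 result.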
30+
31+
// Downcasts `$ARRAY` to the concrete array type `$TYPE`, applies the method
// `$FUNC` to every value, casts each result to `i64`, and yields
// `Ok(Arc<$RESULT>)`.
//
// When the downcast fails, the macro yields
// `Err(DataFusionError::Internal("Invalid data type for <$NAME>"))`.
//
// NOTE(review): the per-element cast is fixed to `i64`, so `$RESULT` is
// effectively expected to be `Int64Array` even though it is a parameter.
macro_rules! downcast_compute_op {
    ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
        if let Some(typed) = $ARRAY.as_any().downcast_ref::<$TYPE>() {
            let converted: $RESULT =
                arrow::compute::kernels::arity::unary(typed, |v| v.$FUNC() as i64);
            Ok(Arc::new(converted))
        } else {
            Err(DataFusionError::Internal(format!(
                "Invalid data type for {}",
                $NAME
            )))
        }
    }};
}
47+
48+
pub fn spark_ceil(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
49+
let value = &args[0];
50+
match value {
51+
ColumnarValue::Array(array) => match array.data_type() {
52+
DataType::Float32 => {
53+
let result =
54+
downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array);
55+
Ok(ColumnarValue::Array(result?))
56+
}
57+
DataType::Float64 => {
58+
let result =
59+
downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array);
60+
Ok(ColumnarValue::Array(result?))
61+
}
62+
DataType::Int8 => {
63+
let input = array.as_any().downcast_ref::<Int8Array>().unwrap();
64+
let result: Int64Array = unary(input, |x| x as i64);
65+
Ok(ColumnarValue::Array(Arc::new(result)))
66+
}
67+
DataType::Int16 => {
68+
let input = array.as_any().downcast_ref::<Int16Array>().unwrap();
69+
let result: Int64Array = unary(input, |x| x as i64);
70+
Ok(ColumnarValue::Array(Arc::new(result)))
71+
}
72+
DataType::Int32 => {
73+
let input = array.as_any().downcast_ref::<Int32Array>().unwrap();
74+
let result: Int64Array = unary(input, |x| x as i64);
75+
Ok(ColumnarValue::Array(Arc::new(result)))
76+
}
77+
DataType::Int64 => {
78+
// Optimization: Int64 -> Int64 doesn't need conversion, just return same array
79+
Ok(ColumnarValue::Array(Arc::clone(array)))
80+
}
81+
other => Err(DataFusionError::Internal(format!(
82+
"Unsupported data type {other:?} for function ceil",
83+
))),
84+
},
85+
ColumnarValue::Scalar(a) => match a {
86+
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
87+
a.map(|x| x.ceil() as i64),
88+
))),
89+
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
90+
a.map(|x| x.ceil() as i64),
91+
))),
92+
ScalarValue::Int8(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
93+
a.map(|x| x as i64),
94+
))),
95+
ScalarValue::Int16(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
96+
a.map(|x| x as i64),
97+
))),
98+
ScalarValue::Int32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
99+
a.map(|x| x as i64),
100+
))),
101+
ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(*a))),
102+
_ => Err(DataFusionError::Internal(format!(
103+
"Unsupported data type {:?} for function ceil",
104+
value.data_type(),
105+
))),
106+
},
107+
}
108+
}
109+
110+
#[derive(Debug, PartialEq, Eq, Hash)]
111+
pub struct SparkCiel {
112+
signature: Signature,
113+
}
114+
115+
impl Default for SparkCiel {
116+
fn default() -> Self {
117+
Self::new()
118+
}
119+
}
120+
121+
impl SparkCiel {
122+
pub fn new() -> Self {
123+
Self {
124+
signature: Signature::numeric(1, Volatility::Immutable),
125+
}
126+
}
127+
}
128+
129+
impl ScalarUDFImpl for SparkCiel {
130+
fn as_any(&self) -> &dyn Any {
131+
self
132+
}
133+
134+
fn name(&self) -> &str {
135+
"ceil"
136+
}
137+
138+
fn signature(&self) -> &Signature {
139+
&self.signature
140+
}
141+
142+
fn return_type(
143+
&self,
144+
_arg_types: &[DataType],
145+
) -> datafusion_common::Result<DataType> {
146+
Ok(DataType::Int64)
147+
}
148+
149+
fn invoke_with_args(
150+
&self,
151+
args: ScalarFunctionArgs,
152+
) -> datafusion_common::Result<ColumnarValue> {
153+
spark_ceil(&args.args)
154+
}
155+
}

datafusion/spark/src/function/math/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
pub mod abs;
19+
mod ceil;
1920
pub mod expm1;
2021
pub mod factorial;
2122
pub mod hex;
@@ -42,6 +43,7 @@ make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
4243
make_udf_function!(trigonometry::SparkCsc, csc);
4344
make_udf_function!(trigonometry::SparkSec, sec);
4445
make_udf_function!(negative::SparkNegative, negative);
46+
make_udf_function!(ceil::SparkCiel, ceil);
4547

4648
pub mod expr_fn {
4749
use datafusion_functions::export_functions;
@@ -70,6 +72,7 @@ pub mod expr_fn {
7072
"Returns the negation of expr (unary minus).",
7173
arg1
7274
));
75+
export_functions!((ceil, "Returns the ceiling of expr.", arg1));
7376
}
7477

7578
pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -86,5 +89,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
8689
csc(),
8790
sec(),
8891
negative(),
92+
ceil(),
8993
]
9094
}

0 commit comments

Comments
 (0)