Skip to content

Commit 9e71267

Browse files
committed
df_int_timestamp_cast
1 parent 9660c98 commit 9e71267

File tree

3 files changed

+448
-1
lines changed

3 files changed

+448
-1
lines changed
Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef, AsArray, TimestampMicrosecondBuilder};
19+
use arrow::datatypes::{
20+
ArrowPrimitiveType, DataType, Int8Type, Int16Type, Int32Type, Int64Type, TimeUnit,
21+
};
22+
use datafusion_common::{Result as DataFusionResult, ScalarValue, exec_err};
23+
use datafusion_expr::{
24+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
25+
};
26+
use std::any::Any;
27+
use std::sync::Arc;
28+
const MICROS_PER_SECOND: i64 = 1_000_000;
29+
30+
#[derive(Debug, PartialEq, Eq, Hash)]
31+
pub struct Cast {
32+
signature: Signature,
33+
}
34+
impl Default for Cast {
35+
fn default() -> Self {
36+
Self::new()
37+
}
38+
}
39+
40+
impl Cast {
41+
pub fn new() -> Self {
42+
Self {
43+
signature: Signature::any(1, Volatility::Immutable),
44+
}
45+
}
46+
}
47+
48+
fn cast_int_to_timestamp<T: ArrowPrimitiveType>(
49+
array: &ArrayRef,
50+
) -> DataFusionResult<ArrayRef>
51+
where
52+
T::Native: Into<i64>,
53+
{
54+
let arr = array.as_primitive::<T>();
55+
let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len());
56+
57+
for i in 0..arr.len() {
58+
if arr.is_null(i) {
59+
builder.append_null();
60+
} else {
61+
let micros = (arr.value(i).into()).saturating_mul(MICROS_PER_SECOND);
62+
builder.append_value(micros);
63+
}
64+
}
65+
66+
Ok(Arc::new(builder.finish()))
67+
}
68+
69+
impl ScalarUDFImpl for Cast {
70+
fn as_any(&self) -> &dyn Any {
71+
self
72+
}
73+
74+
fn name(&self) -> &str {
75+
"spark_cast"
76+
}
77+
78+
fn signature(&self) -> &Signature {
79+
&self.signature
80+
}
81+
82+
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
83+
// for now we will be supporting int -> timestamp and keep adding more spark-compatible spark
84+
match &arg_types[0] {
85+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
86+
Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
87+
}
88+
_ => exec_err!("Unsupported cast from {:?}", arg_types[0]),
89+
}
90+
}
91+
92+
fn invoke_with_args(
93+
&self,
94+
args: ScalarFunctionArgs,
95+
) -> DataFusionResult<ColumnarValue> {
96+
let input = &args.args[0];
97+
match input {
98+
ColumnarValue::Array(array) => match array.data_type() {
99+
DataType::Int8 => {
100+
let result = cast_int_to_timestamp::<Int8Type>(array)?;
101+
Ok(ColumnarValue::Array(result))
102+
}
103+
DataType::Int16 => {
104+
let result = cast_int_to_timestamp::<Int16Type>(array)?;
105+
Ok(ColumnarValue::Array(result))
106+
}
107+
DataType::Int32 => {
108+
let result = cast_int_to_timestamp::<Int32Type>(array)?;
109+
Ok(ColumnarValue::Array(result))
110+
}
111+
DataType::Int64 => {
112+
let result = cast_int_to_timestamp::<Int64Type>(array)?;
113+
Ok(ColumnarValue::Array(result))
114+
}
115+
_ => exec_err!(
116+
"Unsupported cast from {:?} to timestamp",
117+
array.data_type()
118+
),
119+
},
120+
ColumnarValue::Scalar(scalar) => {
121+
// Handle scalar conversions
122+
match scalar {
123+
ScalarValue::Int8(None)
124+
| ScalarValue::Int16(None)
125+
| ScalarValue::Int32(None)
126+
| ScalarValue::Int64(None) => Ok(ColumnarValue::Scalar(
127+
ScalarValue::TimestampMicrosecond(None, None),
128+
)),
129+
ScalarValue::Int8(Some(v)) => {
130+
let micros = (*v as i64).saturating_mul(MICROS_PER_SECOND);
131+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
132+
Some(micros),
133+
None,
134+
)))
135+
}
136+
ScalarValue::Int16(Some(v)) => {
137+
let micros = (*v as i64).saturating_mul(MICROS_PER_SECOND);
138+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
139+
Some(micros),
140+
None,
141+
)))
142+
}
143+
ScalarValue::Int32(Some(v)) => {
144+
let micros = (*v as i64).saturating_mul(MICROS_PER_SECOND);
145+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
146+
Some(micros),
147+
None,
148+
)))
149+
}
150+
ScalarValue::Int64(Some(v)) => {
151+
let micros = (*v).saturating_mul(MICROS_PER_SECOND);
152+
Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
153+
Some(micros),
154+
None,
155+
)))
156+
}
157+
_ => exec_err!("Unsupported cast from {:?} to timestamp", scalar),
158+
}
159+
}
160+
}
161+
}
162+
}
163+
164+
#[cfg(test)]
165+
mod tests {
166+
use super::*;
167+
use arrow::array::{Int8Array, Int16Array, Int32Array, Int64Array};
168+
use arrow::datatypes::{Field, TimestampMicrosecondType};
169+
use datafusion_expr::ScalarFunctionArgs;
170+
171+
fn make_args(input: ColumnarValue) -> ScalarFunctionArgs {
172+
let return_field = Arc::new(Field::new(
173+
"result",
174+
DataType::Timestamp(TimeUnit::Microsecond, None),
175+
true,
176+
));
177+
ScalarFunctionArgs {
178+
args: vec![input],
179+
arg_fields: vec![],
180+
number_rows: 0,
181+
return_field,
182+
config_options: Arc::new(Default::default()),
183+
}
184+
}
185+
186+
fn assert_scalar_timestamp(result: ColumnarValue, expected: i64) {
187+
match result {
188+
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(val), None)) => {
189+
assert_eq!(val, expected);
190+
}
191+
_ => panic!("Expected scalar timestamp with value {expected}"),
192+
}
193+
}
194+
195+
fn assert_scalar_null(result: ColumnarValue) {
196+
assert!(matches!(
197+
result,
198+
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(None, None))
199+
));
200+
}
201+
202+
#[test]
203+
fn test_cast_int8_array_to_timestamp() {
204+
let array: ArrayRef = Arc::new(Int8Array::from(vec![
205+
Some(0),
206+
Some(1),
207+
Some(-1),
208+
Some(127),
209+
Some(-128),
210+
None,
211+
]));
212+
213+
let cast = Cast::new();
214+
let args = make_args(ColumnarValue::Array(array));
215+
let result = cast.invoke_with_args(args).unwrap();
216+
217+
match result {
218+
ColumnarValue::Array(result_array) => {
219+
let ts_array = result_array.as_primitive::<TimestampMicrosecondType>();
220+
assert_eq!(ts_array.value(0), 0);
221+
assert_eq!(ts_array.value(1), 1_000_000);
222+
assert_eq!(ts_array.value(2), -1_000_000);
223+
assert_eq!(ts_array.value(3), 127_000_000);
224+
assert_eq!(ts_array.value(4), -128_000_000);
225+
assert!(ts_array.is_null(5));
226+
}
227+
_ => panic!("Expected array result"),
228+
}
229+
}
230+
231+
#[test]
232+
fn test_cast_int16_array_to_timestamp() {
233+
let array: ArrayRef = Arc::new(Int16Array::from(vec![
234+
Some(0),
235+
Some(32767),
236+
Some(-32768),
237+
None,
238+
]));
239+
240+
let cast = Cast::new();
241+
let args = make_args(ColumnarValue::Array(array));
242+
let result = cast.invoke_with_args(args).unwrap();
243+
244+
match result {
245+
ColumnarValue::Array(result_array) => {
246+
let ts_array = result_array.as_primitive::<TimestampMicrosecondType>();
247+
assert_eq!(ts_array.value(0), 0);
248+
assert_eq!(ts_array.value(1), 32_767_000_000);
249+
assert_eq!(ts_array.value(2), -32_768_000_000);
250+
assert!(ts_array.is_null(3));
251+
}
252+
_ => panic!("Expected array result"),
253+
}
254+
}
255+
256+
#[test]
257+
fn test_cast_int32_array_to_timestamp() {
258+
let array: ArrayRef =
259+
Arc::new(Int32Array::from(vec![Some(0), Some(1704067200), None]));
260+
261+
let cast = Cast::new();
262+
let args = make_args(ColumnarValue::Array(array));
263+
let result = cast.invoke_with_args(args).unwrap();
264+
265+
match result {
266+
ColumnarValue::Array(result_array) => {
267+
let ts_array = result_array.as_primitive::<TimestampMicrosecondType>();
268+
assert_eq!(ts_array.value(0), 0);
269+
assert_eq!(ts_array.value(1), 1_704_067_200_000_000);
270+
assert!(ts_array.is_null(2));
271+
}
272+
_ => panic!("Expected array result"),
273+
}
274+
}
275+
276+
#[test]
277+
fn test_cast_int64_array_overflow() {
278+
let array: ArrayRef =
279+
Arc::new(Int64Array::from(vec![Some(i64::MAX), Some(i64::MIN)]));
280+
281+
let cast = Cast::new();
282+
let args = make_args(ColumnarValue::Array(array));
283+
let result = cast.invoke_with_args(args).unwrap();
284+
285+
match result {
286+
ColumnarValue::Array(result_array) => {
287+
let ts_array = result_array.as_primitive::<TimestampMicrosecondType>();
288+
assert_eq!(ts_array.value(0), i64::MAX);
289+
assert_eq!(ts_array.value(1), i64::MIN);
290+
}
291+
_ => panic!("Expected array result"),
292+
}
293+
}
294+
295+
#[test]
296+
fn test_cast_scalar_int8() {
297+
let cast = Cast::new();
298+
let args = make_args(ColumnarValue::Scalar(ScalarValue::Int8(Some(100))));
299+
let result = cast.invoke_with_args(args).unwrap();
300+
assert_scalar_timestamp(result, 100_000_000);
301+
}
302+
303+
#[test]
304+
fn test_cast_scalar_int32() {
305+
let cast = Cast::new();
306+
let args = make_args(ColumnarValue::Scalar(ScalarValue::Int32(Some(1704067200))));
307+
let result = cast.invoke_with_args(args).unwrap();
308+
assert_scalar_timestamp(result, 1_704_067_200_000_000);
309+
}
310+
311+
#[test]
312+
fn test_cast_scalar_null() {
313+
let cast = Cast::new();
314+
let args = make_args(ColumnarValue::Scalar(ScalarValue::Int64(None)));
315+
let result = cast.invoke_with_args(args).unwrap();
316+
assert_scalar_null(result);
317+
}
318+
319+
#[test]
320+
fn test_cast_scalar_int64_overflow() {
321+
let cast = Cast::new();
322+
let args = make_args(ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))));
323+
let result = cast.invoke_with_args(args).unwrap();
324+
assert_scalar_timestamp(result, i64::MAX);
325+
}
326+
327+
#[test]
328+
fn test_unsupported_scalar_type() {
329+
let cast = Cast::new();
330+
let args = make_args(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
331+
"2024-01-01".to_string(),
332+
))));
333+
let result = cast.invoke_with_args(args);
334+
assert!(result.is_err());
335+
assert!(
336+
result
337+
.unwrap_err()
338+
.to_string()
339+
.contains("Unsupported cast from")
340+
);
341+
}
342+
343+
#[test]
344+
fn test_unsupported_array_type() {
345+
let cast = Cast::new();
346+
let array: ArrayRef =
347+
Arc::new(arrow::array::Float32Array::from(vec![1.0, 2.0, 3.0]));
348+
let args = make_args(ColumnarValue::Array(array));
349+
let result = cast.invoke_with_args(args);
350+
assert!(result.is_err());
351+
assert!(
352+
result
353+
.unwrap_err()
354+
.to_string()
355+
.contains("Unsupported cast from")
356+
);
357+
}
358+
}

datafusion/spark/src/function/conversion/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod cast;
19+
20+
use cast::Cast;
1821
use datafusion_expr::ScalarUDF;
1922
use std::sync::Arc;
2023

2124
pub mod expr_fn {}
2225

2326
pub fn functions() -> Vec<Arc<ScalarUDF>> {
24-
vec![]
27+
vec![Arc::new(ScalarUDF::from(Cast::new()))]
2528
}

0 commit comments

Comments
 (0)