Skip to content

Commit 665d593

Browse files
committed
fix_review_comments
1 parent 532377e commit 665d593

File tree

3 files changed

+93
-36
lines changed

3 files changed

+93
-36
lines changed

datafusion/spark/src/function/conversion/cast.rs

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ use arrow::datatypes::{
2020
ArrowPrimitiveType, DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type,
2121
Int64Type, TimeUnit,
2222
};
23-
use datafusion::logical_expr::{Coercion, TypeSignatureClass};
2423
use datafusion_common::config::ConfigOptions;
25-
use datafusion_common::types::logical_string;
26-
use datafusion_common::{
27-
Result as DataFusionResult, ScalarValue, exec_err, internal_err,
24+
use datafusion_common::types::{
25+
logical_int8, logical_int16, logical_int32, logical_int64, logical_string,
2826
};
29-
use datafusion_expr::TypeSignatureClass::Integer;
27+
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
28+
use datafusion_expr::{Coercion, TypeSignatureClass};
3029
use datafusion_expr::{
3130
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
3231
Signature, TypeSignature, Volatility,
@@ -81,21 +80,28 @@ impl SparkCast {
8180
}
8281

8382
pub fn new_with_config(config: &ConfigOptions) -> Self {
84-
// First arg: value to cast (only ints for now with potential to add further support later)
83+
// First arg: value to cast (only signed ints - Spark doesn't have unsigned integers)
8584
// Second arg: target datatype as Utf8 string literal (ex : 'timestamp')
86-
let int_arg = Coercion::new_exact(Integer);
8785
let string_arg =
8886
Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
87+
88+
// Spark only supports signed integers, so we explicitly list them
89+
let signed_int_signatures = [
90+
logical_int8(),
91+
logical_int16(),
92+
logical_int32(),
93+
logical_int64(),
94+
]
95+
.map(|int_type| {
96+
TypeSignature::Coercible(vec![
97+
Coercion::new_exact(TypeSignatureClass::Native(int_type)),
98+
string_arg.clone(),
99+
])
100+
});
101+
89102
Self {
90-
signature: Signature::one_of(
91-
vec![
92-
TypeSignature::Coercible(vec![int_arg.clone(), string_arg.clone()]),
93-
TypeSignature::Coercible(vec![
94-
int_arg,
95-
string_arg.clone(),
96-
string_arg,
97-
]),
98-
],
103+
signature: Signature::new(
104+
TypeSignature::OneOf(Vec::from(signed_int_signatures)),
99105
Volatility::Stable,
100106
),
101107
timezone: config
@@ -109,10 +115,7 @@ impl SparkCast {
109115
}
110116

111117
/// Parse target type string into a DataType
112-
fn parse_target_type(
113-
type_str: &str,
114-
timezone: Option<Arc<str>>,
115-
) -> DataFusionResult<DataType> {
118+
fn parse_target_type(type_str: &str, timezone: Option<Arc<str>>) -> Result<DataType> {
116119
match type_str.to_lowercase().as_str() {
117120
// further data type support in future
118121
"timestamp" => Ok(DataType::Timestamp(TimeUnit::Microsecond, timezone)),
@@ -127,7 +130,7 @@ fn parse_target_type(
127130
fn get_target_type_from_scalar_args(
128131
scalar_args: &[Option<&ScalarValue>],
129132
timezone: Option<Arc<str>>,
130-
) -> DataFusionResult<DataType> {
133+
) -> Result<DataType> {
131134
let type_arg = scalar_args.get(1).and_then(|opt| *opt);
132135

133136
match type_arg {
@@ -143,7 +146,7 @@ fn get_target_type_from_scalar_args(
143146
fn cast_int_to_timestamp<T: ArrowPrimitiveType>(
144147
array: &ArrayRef,
145148
timezone: Option<Arc<str>>,
146-
) -> DataFusionResult<ArrayRef>
149+
) -> Result<ArrayRef>
147150
where
148151
T::Native: Into<i64>,
149152
{
@@ -176,18 +179,15 @@ impl ScalarUDFImpl for SparkCast {
176179
&self.signature
177180
}
178181

179-
fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
182+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
180183
internal_err!("return_field_from_args should be used instead")
181184
}
182185

183186
fn with_updated_config(&self, config: &ConfigOptions) -> Option<ScalarUDF> {
184187
Some(ScalarUDF::from(Self::new_with_config(config)))
185188
}
186189

187-
fn return_field_from_args(
188-
&self,
189-
args: ReturnFieldArgs,
190-
) -> DataFusionResult<FieldRef> {
190+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
191191
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
192192
let return_type = get_target_type_from_scalar_args(
193193
args.scalar_arguments,
@@ -196,10 +196,7 @@ impl ScalarUDFImpl for SparkCast {
196196
Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
197197
}
198198

199-
fn invoke_with_args(
200-
&self,
201-
args: ScalarFunctionArgs,
202-
) -> DataFusionResult<ColumnarValue> {
199+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
203200
let target_type = args.return_field.data_type();
204201
match target_type {
205202
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
@@ -214,7 +211,7 @@ impl ScalarUDFImpl for SparkCast {
214211
fn cast_to_timestamp(
215212
input: &ColumnarValue,
216213
timezone: Option<Arc<str>>,
217-
) -> DataFusionResult<ColumnarValue> {
214+
) -> Result<ColumnarValue> {
218215
match input {
219216
ColumnarValue::Array(array) => match array.data_type() {
220217
DataType::Null => Ok(ColumnarValue::Array(Arc::new(

datafusion/spark/src/function/conversion/mod.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,23 @@
1818
mod cast;
1919

2020
use datafusion_expr::ScalarUDF;
21-
use datafusion_functions::make_udf_function;
21+
use datafusion_functions::make_udf_function_with_config;
2222
use std::sync::Arc;
2323

24-
make_udf_function!(cast::SparkCast, spark_cast);
24+
make_udf_function_with_config!(cast::SparkCast, spark_cast);
2525

2626
pub mod expr_fn {
2727
use datafusion_functions::export_functions;
2828

2929
export_functions!((
3030
spark_cast,
3131
"Casts given value to the specified type following Spark-compatible semantics",
32-
arg1 arg2
32+
@config arg1 arg2
3333
));
3434
}
3535

3636
pub fn functions() -> Vec<Arc<ScalarUDF>> {
37-
vec![spark_cast()]
37+
use datafusion_common::config::ConfigOptions;
38+
let config = ConfigOptions::default();
39+
vec![spark_cast(&config)]
3840
}

datafusion/sqllogictest/test_files/spark/conversion/cast_int_to_timestamp.slt

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,61 @@ SELECT spark_cast(1710057600::bigint, 'timestamp');
192192
# Reset to default UTC
193193
statement ok
194194
SET datafusion.execution.time_zone = 'UTC';
195+
196+
#############################
197+
# Array Tests
198+
#############################
199+
200+
# Create test table with 4 int columns: null, min, max, regular value
201+
statement ok
202+
CREATE TABLE int_test AS SELECT
203+
arrow_cast(column1, 'Int8') as i8_col,
204+
arrow_cast(column2, 'Int16') as i16_col,
205+
arrow_cast(column3, 'Int32') as i32_col,
206+
column4::bigint as i64_col
207+
FROM (VALUES
208+
(NULL, NULL, NULL, NULL),
209+
(-128, -32768, -2147483648, -86400),
210+
(127, 32767, 2147483647, 86400),
211+
(100, 3600, 1710054000, 1710054000)
212+
);
213+
214+
# Test in UTC
215+
query PPPP
216+
SELECT spark_cast(i8_col, 'timestamp'), spark_cast(i16_col, 'timestamp'), spark_cast(i32_col, 'timestamp'), spark_cast(i64_col, 'timestamp') FROM int_test;
217+
----
218+
NULL NULL NULL NULL
219+
1969-12-31T23:57:52Z 1969-12-31T14:53:52Z 1901-12-13T20:45:52Z 1969-12-31T00:00:00Z
220+
1970-01-01T00:02:07Z 1970-01-01T09:06:07Z 2038-01-19T03:14:07Z 1970-01-02T00:00:00Z
221+
1970-01-01T00:01:40Z 1970-01-01T01:00:00Z 2024-03-10T07:00:00Z 2024-03-10T07:00:00Z
222+
223+
# Test in America/Los_Angeles (PST - has DST)
224+
statement ok
225+
SET datafusion.execution.time_zone = 'America/Los_Angeles';
226+
227+
query PPPP
228+
SELECT spark_cast(i8_col, 'timestamp'), spark_cast(i16_col, 'timestamp'), spark_cast(i32_col, 'timestamp'), spark_cast(i64_col, 'timestamp') FROM int_test;
229+
----
230+
NULL NULL NULL NULL
231+
1969-12-31T15:57:52-08:00 1969-12-31T06:53:52-08:00 1901-12-13T12:45:52-08:00 1969-12-30T16:00:00-08:00
232+
1969-12-31T16:02:07-08:00 1970-01-01T01:06:07-08:00 2038-01-18T19:14:07-08:00 1970-01-01T16:00:00-08:00
233+
1969-12-31T16:01:40-08:00 1969-12-31T17:00:00-08:00 2024-03-09T23:00:00-08:00 2024-03-09T23:00:00-08:00
234+
235+
# Test in America/Phoenix (MST - no DST, always UTC-7)
236+
statement ok
237+
SET datafusion.execution.time_zone = 'America/Phoenix';
238+
239+
query PPPP
240+
SELECT spark_cast(i8_col, 'timestamp'), spark_cast(i16_col, 'timestamp'), spark_cast(i32_col, 'timestamp'), spark_cast(i64_col, 'timestamp') FROM int_test;
241+
----
242+
NULL NULL NULL NULL
243+
1969-12-31T16:57:52-07:00 1969-12-31T07:53:52-07:00 1901-12-13T13:45:52-07:00 1969-12-30T17:00:00-07:00
244+
1969-12-31T17:02:07-07:00 1970-01-01T02:06:07-07:00 2038-01-18T20:14:07-07:00 1970-01-01T17:00:00-07:00
245+
1969-12-31T17:01:40-07:00 1969-12-31T18:00:00-07:00 2024-03-10T00:00:00-07:00 2024-03-10T00:00:00-07:00
246+
247+
# Reset and cleanup
248+
statement ok
249+
SET datafusion.execution.time_zone = 'UTC';
250+
251+
statement ok
252+
DROP TABLE int_test;

0 commit comments

Comments
 (0)