Skip to content

Commit 532377e

Browse files
committed
fix_review_comments
1 parent 228a24e commit 532377e

File tree

2 files changed

+145
-43
lines changed

2 files changed

+145
-43
lines changed

datafusion/spark/src/function/conversion/cast.rs

Lines changed: 65 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,23 @@ use arrow::datatypes::{
2020
ArrowPrimitiveType, DataType, Field, FieldRef, Int8Type, Int16Type, Int32Type,
2121
Int64Type, TimeUnit,
2222
};
23-
use datafusion_common::utils::take_function_args;
23+
use datafusion::logical_expr::{Coercion, TypeSignatureClass};
24+
use datafusion_common::config::ConfigOptions;
25+
use datafusion_common::types::logical_string;
2426
use datafusion_common::{
2527
Result as DataFusionResult, ScalarValue, exec_err, internal_err,
2628
};
27-
use datafusion_expr::{ColumnarValue, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility};
29+
use datafusion_expr::TypeSignatureClass::Integer;
30+
use datafusion_expr::{
31+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
32+
Signature, TypeSignature, Volatility,
33+
};
2834
use std::any::Any;
2935
use std::sync::Arc;
30-
use datafusion::logical_expr::{Coercion, TypeSignatureClass};
31-
use datafusion_common::types::{logical_int64, logical_string};
32-
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
3336

3437
const MICROS_PER_SECOND: i64 = 1_000_000;
3538

36-
/// Convert seconds to microseconds with saturating overflow behavior
39+
/// Convert seconds to microseconds, saturating on overflow (matches Spark behavior)
3740
#[inline]
3841
fn secs_to_micros(secs: i64) -> i64 {
3942
secs.saturating_mul(MICROS_PER_SECOND)
@@ -63,6 +66,7 @@ fn secs_to_micros(secs: i64) -> i64 {
6366
#[derive(Debug, PartialEq, Eq, Hash)]
6467
pub struct SparkCast {
6568
signature: Signature,
69+
timezone: Option<Arc<str>>,
6670
}
6771

6872
impl Default for SparkCast {
@@ -73,24 +77,45 @@ impl Default for SparkCast {
7377

7478
impl SparkCast {
7579
pub fn new() -> Self {
80+
Self::new_with_config(&ConfigOptions::default())
81+
}
82+
83+
pub fn new_with_config(config: &ConfigOptions) -> Self {
7684
// First arg: value to cast (only integer types for now; other source types may be supported later)
7785
// Second arg: target data type as a Utf8 string literal (e.g. 'timestamp')
78-
let int_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_int64()));
79-
let string_arg = Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
86+
let int_arg = Coercion::new_exact(Integer);
87+
let string_arg =
88+
Coercion::new_exact(TypeSignatureClass::Native(logical_string()));
8089
Self {
8190
signature: Signature::one_of(
82-
vec![TypeSignature::Coercible(vec![int_arg, string_arg])],
91+
vec![
92+
TypeSignature::Coercible(vec![int_arg.clone(), string_arg.clone()]),
93+
TypeSignature::Coercible(vec![
94+
int_arg,
95+
string_arg.clone(),
96+
string_arg,
97+
]),
98+
],
8399
Volatility::Stable,
84100
),
101+
timezone: config
102+
.execution
103+
.time_zone
104+
.as_ref()
105+
.map(|tz| Arc::from(tz.as_str()))
106+
.or_else(|| Some(Arc::from("UTC"))),
85107
}
86108
}
87109
}
88110

89111
/// Parse target type string into a DataType
90-
fn parse_target_type(type_str: &str) -> DataFusionResult<DataType> {
112+
fn parse_target_type(
113+
type_str: &str,
114+
timezone: Option<Arc<str>>,
115+
) -> DataFusionResult<DataType> {
91116
match type_str.to_lowercase().as_str() {
92117
// further data type support in future
93-
"timestamp" => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
118+
"timestamp" => Ok(DataType::Timestamp(TimeUnit::Microsecond, timezone)),
94119
other => exec_err!(
95120
"Unsupported spark_cast target type '{}'. Supported types: timestamp",
96121
other
@@ -101,13 +126,14 @@ fn parse_target_type(type_str: &str) -> DataFusionResult<DataType> {
101126
/// Extract target type string from scalar arguments
102127
fn get_target_type_from_scalar_args(
103128
scalar_args: &[Option<&ScalarValue>],
129+
timezone: Option<Arc<str>>,
104130
) -> DataFusionResult<DataType> {
105-
let [_, type_arg] = take_function_args("spark_cast", scalar_args)?;
131+
let type_arg = scalar_args.get(1).and_then(|opt| *opt);
106132

107133
match type_arg {
108-
Some(ScalarValue::Utf8(Some(s))) | Some(ScalarValue::LargeUtf8(Some(s))) => {
109-
parse_target_type(s)
110-
}
134+
Some(ScalarValue::Utf8(Some(s)))
135+
| Some(ScalarValue::LargeUtf8(Some(s)))
136+
| Some(ScalarValue::Utf8View(Some(s))) => parse_target_type(s, timezone),
111137
_ => exec_err!(
112138
"spark_cast requires second argument to be a string of target data type ex: timestamp"
113139
),
@@ -154,23 +180,30 @@ impl ScalarUDFImpl for SparkCast {
154180
internal_err!("return_field_from_args should be used instead")
155181
}
156182

183+
fn with_updated_config(&self, config: &ConfigOptions) -> Option<ScalarUDF> {
184+
Some(ScalarUDF::from(Self::new_with_config(config)))
185+
}
186+
187+
fn return_field_from_args(
188+
&self,
189+
args: ReturnFieldArgs,
190+
) -> DataFusionResult<FieldRef> {
191+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
192+
let return_type = get_target_type_from_scalar_args(
193+
args.scalar_arguments,
194+
self.timezone.clone(),
195+
)?;
196+
Ok(Arc::new(Field::new(self.name(), return_type, nullable)))
197+
}
198+
157199
fn invoke_with_args(
158200
&self,
159201
args: ScalarFunctionArgs,
160202
) -> DataFusionResult<ColumnarValue> {
161203
let target_type = args.return_field.data_type();
162-
// Use session timezone, fallback to UTC if not set
163-
let session_tz: Arc<str> = args
164-
.config_options
165-
.execution
166-
.time_zone
167-
.clone()
168-
.map(|s| Arc::from(s.as_str()))
169-
.unwrap_or_else(|| Arc::from("UTC"));
170-
171204
match target_type {
172-
DataType::Timestamp(TimeUnit::Microsecond, _) => {
173-
cast_to_timestamp(&args.args[0], Some(session_tz))
205+
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
206+
cast_to_timestamp(&args.args[0], tz.clone())
174207
}
175208
other => exec_err!("Unsupported spark_cast target type: {:?}", other),
176209
}
@@ -232,7 +265,7 @@ mod tests {
232265

233266
// helpers to make testing easier
234267
fn make_args(input: ColumnarValue, target_type: &str) -> ScalarFunctionArgs {
235-
make_args_with_timezone(input, target_type, None)
268+
make_args_with_timezone(input, target_type, Some("UTC"))
236269
}
237270

238271
fn make_args_with_timezone(
@@ -242,10 +275,13 @@ mod tests {
242275
) -> ScalarFunctionArgs {
243276
let return_field = Arc::new(Field::new(
244277
"result",
245-
DataType::Timestamp(TimeUnit::Microsecond, None),
278+
DataType::Timestamp(
279+
TimeUnit::Microsecond,
280+
Some(Arc::from(timezone.unwrap())),
281+
),
246282
true,
247283
));
248-
let mut config = datafusion_common::config::ConfigOptions::default();
284+
let mut config = ConfigOptions::default();
249285
if let Some(tz) = timezone {
250286
config.execution.time_zone = Some(tz.to_string());
251287
}

datafusion/sqllogictest/test_files/spark/conversion/cast_int_to_timestamp.slt

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,50 +19,50 @@
1919
query P
2020
SELECT spark_cast(arrow_cast(0, 'Int8'), 'timestamp');
2121
----
22-
1970-01-01T00:00:00
22+
1970-01-01T00:00:00Z
2323

2424
query P
2525
SELECT spark_cast(arrow_cast(1, 'Int8'), 'timestamp');
2626
----
27-
1970-01-01T00:00:01
27+
1970-01-01T00:00:01Z
2828

2929
query P
3030
SELECT spark_cast(arrow_cast(-1, 'Int8'), 'timestamp');
3131
----
32-
1969-12-31T23:59:59
32+
1969-12-31T23:59:59Z
3333

3434
# Test spark_cast from int16 to timestamp
3535
query P
3636
SELECT spark_cast(arrow_cast(0, 'Int16'), 'timestamp');
3737
----
38-
1970-01-01T00:00:00
38+
1970-01-01T00:00:00Z
3939

4040
query P
4141
SELECT spark_cast(arrow_cast(3600, 'Int16'), 'timestamp');
4242
----
43-
1970-01-01T01:00:00
43+
1970-01-01T01:00:00Z
4444

4545
# Test spark_cast from int32 to timestamp
4646
query P
4747
SELECT spark_cast(arrow_cast(0, 'Int32'), 'timestamp');
4848
----
49-
1970-01-01T00:00:00
49+
1970-01-01T00:00:00Z
5050

5151
query P
5252
SELECT spark_cast(arrow_cast(1704067200, 'Int32'), 'timestamp');
5353
----
54-
2024-01-01T00:00:00
54+
2024-01-01T00:00:00Z
5555

5656
# Test spark_cast from int64 to timestamp
5757
query P
5858
SELECT spark_cast(0::bigint, 'timestamp');
5959
----
60-
1970-01-01T00:00:00
60+
1970-01-01T00:00:00Z
6161

6262
query P
6363
SELECT spark_cast(1704067200::bigint, 'timestamp');
6464
----
65-
2024-01-01T00:00:00
65+
2024-01-01T00:00:00Z
6666

6767
# Test NULL handling
6868
query P
@@ -95,29 +95,29 @@ NULL
9595
query P
9696
SELECT spark_cast(arrow_cast(127, 'Int8'), 'timestamp');
9797
----
98-
1970-01-01T00:02:07
98+
1970-01-01T00:02:07Z
9999

100100
query P
101101
SELECT spark_cast(arrow_cast(-128, 'Int8'), 'timestamp');
102102
----
103-
1969-12-31T23:57:52
103+
1969-12-31T23:57:52Z
104104

105105
# Test Int16 boundary values
106106
query P
107107
SELECT spark_cast(arrow_cast(32767, 'Int16'), 'timestamp');
108108
----
109-
1970-01-01T09:06:07
109+
1970-01-01T09:06:07Z
110110

111111
query P
112112
SELECT spark_cast(arrow_cast(-32768, 'Int16'), 'timestamp');
113113
----
114-
1969-12-31T14:53:52
114+
1969-12-31T14:53:52Z
115115

116116
# Test Int64 negative value
117117
query P
118118
SELECT spark_cast(-86400::bigint, 'timestamp');
119119
----
120-
1969-12-31T00:00:00
120+
1969-12-31T00:00:00Z
121121

122122
# Test unsupported source type - should error
123123
statement error
@@ -126,3 +126,69 @@ SELECT spark_cast('2024-01-01', 'timestamp');
126126
# Test unsupported target type - should error
127127
statement error
128128
SELECT spark_cast(100, 'string');
129+
130+
# Test with different session timezones to verify with_updated_config() picks up the session time zone
131+
132+
# America/Los_Angeles (PST/PDT - has DST)
133+
statement ok
134+
SET datafusion.execution.time_zone = 'America/Los_Angeles';
135+
136+
# Epoch in PST (UTC-8)
137+
query P
138+
SELECT spark_cast(0::bigint, 'timestamp');
139+
----
140+
1969-12-31T16:00:00-08:00
141+
142+
# 2024-01-01 00:00:00 UTC in PST (winter, UTC-8)
143+
query P
144+
SELECT spark_cast(1704067200::bigint, 'timestamp');
145+
----
146+
2023-12-31T16:00:00-08:00
147+
148+
# America/Phoenix (MST - no DST, always UTC-7)
149+
statement ok
150+
SET datafusion.execution.time_zone = 'America/Phoenix';
151+
152+
# Epoch in Phoenix (UTC-7)
153+
query P
154+
SELECT spark_cast(0::bigint, 'timestamp');
155+
----
156+
1969-12-31T17:00:00-07:00
157+
158+
# 2024-01-01 00:00:00 UTC in Phoenix (still UTC-7, no DST)
159+
query P
160+
SELECT spark_cast(1704067200::bigint, 'timestamp');
161+
----
162+
2023-12-31T17:00:00-07:00
163+
164+
# Test timestamps on 2024-03-10, shortly before the US DST transition - LA (has DST) is still on PST (UTC-8)
165+
statement ok
166+
SET datafusion.execution.time_zone = 'America/Los_Angeles';
167+
168+
query P
169+
SELECT spark_cast(1710054000::bigint, 'timestamp');
170+
----
171+
2024-03-09T23:00:00-08:00
172+
173+
query P
174+
SELECT spark_cast(1710057600::bigint, 'timestamp');
175+
----
176+
2024-03-10T00:00:00-08:00
177+
178+
# Phoenix has no DST - always UTC-7
179+
statement ok
180+
SET datafusion.execution.time_zone = 'America/Phoenix';
181+
182+
query P
183+
SELECT spark_cast(1710054000::bigint, 'timestamp');
184+
----
185+
2024-03-10T00:00:00-07:00
186+
187+
query P
188+
SELECT spark_cast(1710057600::bigint, 'timestamp');
189+
----
190+
2024-03-10T01:00:00-07:00
191+
192+
# Reset to default UTC
193+
statement ok
194+
SET datafusion.execution.time_zone = 'UTC';

0 commit comments

Comments
 (0)