@@ -20,20 +20,23 @@ use arrow::datatypes::{
2020 ArrowPrimitiveType , DataType , Field , FieldRef , Int8Type , Int16Type , Int32Type ,
2121 Int64Type , TimeUnit ,
2222} ;
23- use datafusion_common:: utils:: take_function_args;
23+ use datafusion:: logical_expr:: { Coercion , TypeSignatureClass } ;
24+ use datafusion_common:: config:: ConfigOptions ;
25+ use datafusion_common:: types:: logical_string;
2426use datafusion_common:: {
2527 Result as DataFusionResult , ScalarValue , exec_err, internal_err,
2628} ;
27- use datafusion_expr:: { ColumnarValue , Expr , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature , TypeSignature , Volatility } ;
29+ use datafusion_expr:: TypeSignatureClass :: Integer ;
30+ use datafusion_expr:: {
31+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl ,
32+ Signature , TypeSignature , Volatility ,
33+ } ;
2834use std:: any:: Any ;
2935use std:: sync:: Arc ;
30- use datafusion:: logical_expr:: { Coercion , TypeSignatureClass } ;
31- use datafusion_common:: types:: { logical_int64, logical_string} ;
32- use datafusion_expr:: simplify:: { ExprSimplifyResult , SimplifyContext } ;
3336
3437const MICROS_PER_SECOND : i64 = 1_000_000 ;
3538
36- /// Convert seconds to microseconds with saturating overflow behavior
39+ /// Convert seconds to microseconds with saturating overflow behavior (matches the Spark spec)
3740#[ inline]
3841fn secs_to_micros ( secs : i64 ) -> i64 {
3942 secs. saturating_mul ( MICROS_PER_SECOND )
@@ -63,6 +66,7 @@ fn secs_to_micros(secs: i64) -> i64 {
6366#[ derive( Debug , PartialEq , Eq , Hash ) ]
6467pub struct SparkCast {
6568 signature : Signature ,
69+ timezone : Option < Arc < str > > ,
6670}
6771
6872impl Default for SparkCast {
@@ -73,24 +77,45 @@ impl Default for SparkCast {
7377
7478impl SparkCast {
7579 pub fn new ( ) -> Self {
80+ Self :: new_with_config ( & ConfigOptions :: default ( ) )
81+ }
82+
83+ pub fn new_with_config ( config : & ConfigOptions ) -> Self {
7684 // First arg: value to cast (only ints for now with potential to add further support later)
7785 // Second arg: target datatype as Utf8 string literal (ex : 'timestamp')
78- let int_arg = Coercion :: new_exact ( TypeSignatureClass :: Native ( logical_int64 ( ) ) ) ;
79- let string_arg = Coercion :: new_exact ( TypeSignatureClass :: Native ( logical_string ( ) ) ) ;
86+ let int_arg = Coercion :: new_exact ( Integer ) ;
87+ let string_arg =
88+ Coercion :: new_exact ( TypeSignatureClass :: Native ( logical_string ( ) ) ) ;
8089 Self {
8190 signature : Signature :: one_of (
82- vec ! [ TypeSignature :: Coercible ( vec![ int_arg, string_arg] ) ] ,
91+ vec ! [
92+ TypeSignature :: Coercible ( vec![ int_arg. clone( ) , string_arg. clone( ) ] ) ,
93+ TypeSignature :: Coercible ( vec![
94+ int_arg,
95+ string_arg. clone( ) ,
96+ string_arg,
97+ ] ) ,
98+ ] ,
8399 Volatility :: Stable ,
84100 ) ,
101+ timezone : config
102+ . execution
103+ . time_zone
104+ . as_ref ( )
105+ . map ( |tz| Arc :: from ( tz. as_str ( ) ) )
106+ . or_else ( || Some ( Arc :: from ( "UTC" ) ) ) ,
85107 }
86108 }
87109}
88110
89111/// Parse target type string into a DataType
90- fn parse_target_type ( type_str : & str ) -> DataFusionResult < DataType > {
112+ fn parse_target_type (
113+ type_str : & str ,
114+ timezone : Option < Arc < str > > ,
115+ ) -> DataFusionResult < DataType > {
91116 match type_str. to_lowercase ( ) . as_str ( ) {
92117 // further data type support in future
93- "timestamp" => Ok ( DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ) ,
118+ "timestamp" => Ok ( DataType :: Timestamp ( TimeUnit :: Microsecond , timezone ) ) ,
94119 other => exec_err ! (
95120 "Unsupported spark_cast target type '{}'. Supported types: timestamp" ,
96121 other
@@ -101,13 +126,14 @@ fn parse_target_type(type_str: &str) -> DataFusionResult<DataType> {
101126/// Extract target type string from scalar arguments
102127fn get_target_type_from_scalar_args (
103128 scalar_args : & [ Option < & ScalarValue > ] ,
129+ timezone : Option < Arc < str > > ,
104130) -> DataFusionResult < DataType > {
105- let [ _ , type_arg] = take_function_args ( "spark_cast" , scalar_args ) ? ;
131+ let type_arg = scalar_args . get ( 1 ) . and_then ( |opt| * opt ) ;
106132
107133 match type_arg {
108- Some ( ScalarValue :: Utf8 ( Some ( s) ) ) | Some ( ScalarValue :: LargeUtf8 ( Some ( s ) ) ) => {
109- parse_target_type ( s )
110- }
134+ Some ( ScalarValue :: Utf8 ( Some ( s) ) )
135+ | Some ( ScalarValue :: LargeUtf8 ( Some ( s ) ) )
136+ | Some ( ScalarValue :: Utf8View ( Some ( s ) ) ) => parse_target_type ( s , timezone ) ,
111137 _ => exec_err ! (
112138 "spark_cast requires second argument to be a string of target data type ex: timestamp"
113139 ) ,
@@ -154,23 +180,30 @@ impl ScalarUDFImpl for SparkCast {
154180 internal_err ! ( "return_field_from_args should be used instead" )
155181 }
156182
183+ fn with_updated_config ( & self , config : & ConfigOptions ) -> Option < ScalarUDF > {
184+ Some ( ScalarUDF :: from ( Self :: new_with_config ( config) ) )
185+ }
186+
187+ fn return_field_from_args (
188+ & self ,
189+ args : ReturnFieldArgs ,
190+ ) -> DataFusionResult < FieldRef > {
191+ let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
192+ let return_type = get_target_type_from_scalar_args (
193+ args. scalar_arguments ,
194+ self . timezone . clone ( ) ,
195+ ) ?;
196+ Ok ( Arc :: new ( Field :: new ( self . name ( ) , return_type, nullable) ) )
197+ }
198+
157199 fn invoke_with_args (
158200 & self ,
159201 args : ScalarFunctionArgs ,
160202 ) -> DataFusionResult < ColumnarValue > {
161203 let target_type = args. return_field . data_type ( ) ;
162- // Use session timezone, fallback to UTC if not set
163- let session_tz: Arc < str > = args
164- . config_options
165- . execution
166- . time_zone
167- . clone ( )
168- . map ( |s| Arc :: from ( s. as_str ( ) ) )
169- . unwrap_or_else ( || Arc :: from ( "UTC" ) ) ;
170-
171204 match target_type {
172- DataType :: Timestamp ( TimeUnit :: Microsecond , _ ) => {
173- cast_to_timestamp ( & args. args [ 0 ] , Some ( session_tz ) )
205+ DataType :: Timestamp ( TimeUnit :: Microsecond , tz ) => {
206+ cast_to_timestamp ( & args. args [ 0 ] , tz . clone ( ) )
174207 }
175208 other => exec_err ! ( "Unsupported spark_cast target type: {:?}" , other) ,
176209 }
@@ -232,7 +265,7 @@ mod tests {
232265
233266 // helpers to make testing easier
234267 fn make_args ( input : ColumnarValue , target_type : & str ) -> ScalarFunctionArgs {
235- make_args_with_timezone ( input, target_type, None )
268+ make_args_with_timezone ( input, target_type, Some ( "UTC" ) )
236269 }
237270
238271 fn make_args_with_timezone (
@@ -242,10 +275,13 @@ mod tests {
242275 ) -> ScalarFunctionArgs {
243276 let return_field = Arc :: new ( Field :: new (
244277 "result" ,
245- DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
278+ DataType :: Timestamp (
279+ TimeUnit :: Microsecond ,
280+ Some ( Arc :: from ( timezone. unwrap ( ) ) ) ,
281+ ) ,
246282 true ,
247283 ) ) ;
248- let mut config = datafusion_common :: config :: ConfigOptions :: default ( ) ;
284+ let mut config = ConfigOptions :: default ( ) ;
249285 if let Some ( tz) = timezone {
250286 config. execution . time_zone = Some ( tz. to_string ( ) ) ;
251287 }
0 commit comments