@@ -20,13 +20,12 @@ use arrow::datatypes::{
2020 ArrowPrimitiveType , DataType , Field , FieldRef , Int8Type , Int16Type , Int32Type ,
2121 Int64Type , TimeUnit ,
2222} ;
23- use datafusion:: logical_expr:: { Coercion , TypeSignatureClass } ;
2423use datafusion_common:: config:: ConfigOptions ;
25- use datafusion_common:: types:: logical_string;
26- use datafusion_common:: {
27- Result as DataFusionResult , ScalarValue , exec_err, internal_err,
24+ use datafusion_common:: types:: {
25+ logical_int8, logical_int16, logical_int32, logical_int64, logical_string,
2826} ;
29- use datafusion_expr:: TypeSignatureClass :: Integer ;
27+ use datafusion_common:: { Result , ScalarValue , exec_err, internal_err} ;
28+ use datafusion_expr:: { Coercion , TypeSignatureClass } ;
3029use datafusion_expr:: {
3130 ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDF , ScalarUDFImpl ,
3231 Signature , TypeSignature , Volatility ,
@@ -81,21 +80,28 @@ impl SparkCast {
8180 }
8281
8382 pub fn new_with_config ( config : & ConfigOptions ) -> Self {
84- // First arg: value to cast (only ints for now with potential to add further support later )
83+ // First arg: value to cast (only signed ints - Spark doesn't have unsigned integers )
8584 // Second arg: target datatype as Utf8 string literal (ex : 'timestamp')
86- let int_arg = Coercion :: new_exact ( Integer ) ;
8785 let string_arg =
8886 Coercion :: new_exact ( TypeSignatureClass :: Native ( logical_string ( ) ) ) ;
87+
88+ // Spark only supports signed integers, so we explicitly list them
89+ let signed_int_signatures = [
90+ logical_int8 ( ) ,
91+ logical_int16 ( ) ,
92+ logical_int32 ( ) ,
93+ logical_int64 ( ) ,
94+ ]
95+ . map ( |int_type| {
96+ TypeSignature :: Coercible ( vec ! [
97+ Coercion :: new_exact( TypeSignatureClass :: Native ( int_type) ) ,
98+ string_arg. clone( ) ,
99+ ] )
100+ } ) ;
101+
89102 Self {
90- signature : Signature :: one_of (
91- vec ! [
92- TypeSignature :: Coercible ( vec![ int_arg. clone( ) , string_arg. clone( ) ] ) ,
93- TypeSignature :: Coercible ( vec![
94- int_arg,
95- string_arg. clone( ) ,
96- string_arg,
97- ] ) ,
98- ] ,
103+ signature : Signature :: new (
104+ TypeSignature :: OneOf ( Vec :: from ( signed_int_signatures) ) ,
99105 Volatility :: Stable ,
100106 ) ,
101107 timezone : config
@@ -109,10 +115,7 @@ impl SparkCast {
109115}
110116
111117/// Parse target type string into a DataType
112- fn parse_target_type (
113- type_str : & str ,
114- timezone : Option < Arc < str > > ,
115- ) -> DataFusionResult < DataType > {
118+ fn parse_target_type ( type_str : & str , timezone : Option < Arc < str > > ) -> Result < DataType > {
116119 match type_str. to_lowercase ( ) . as_str ( ) {
117120 // further data type support in future
118121 "timestamp" => Ok ( DataType :: Timestamp ( TimeUnit :: Microsecond , timezone) ) ,
@@ -127,7 +130,7 @@ fn parse_target_type(
127130fn get_target_type_from_scalar_args (
128131 scalar_args : & [ Option < & ScalarValue > ] ,
129132 timezone : Option < Arc < str > > ,
130- ) -> DataFusionResult < DataType > {
133+ ) -> Result < DataType > {
131134 let type_arg = scalar_args. get ( 1 ) . and_then ( |opt| * opt) ;
132135
133136 match type_arg {
@@ -143,7 +146,7 @@ fn get_target_type_from_scalar_args(
143146fn cast_int_to_timestamp < T : ArrowPrimitiveType > (
144147 array : & ArrayRef ,
145148 timezone : Option < Arc < str > > ,
146- ) -> DataFusionResult < ArrayRef >
149+ ) -> Result < ArrayRef >
147150where
148151 T :: Native : Into < i64 > ,
149152{
@@ -176,18 +179,15 @@ impl ScalarUDFImpl for SparkCast {
176179 & self . signature
177180 }
178181
179- fn return_type ( & self , _arg_types : & [ DataType ] ) -> DataFusionResult < DataType > {
182+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
180183 internal_err ! ( "return_field_from_args should be used instead" )
181184 }
182185
183186 fn with_updated_config ( & self , config : & ConfigOptions ) -> Option < ScalarUDF > {
184187 Some ( ScalarUDF :: from ( Self :: new_with_config ( config) ) )
185188 }
186189
187- fn return_field_from_args (
188- & self ,
189- args : ReturnFieldArgs ,
190- ) -> DataFusionResult < FieldRef > {
190+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
191191 let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
192192 let return_type = get_target_type_from_scalar_args (
193193 args. scalar_arguments ,
@@ -196,10 +196,7 @@ impl ScalarUDFImpl for SparkCast {
196196 Ok ( Arc :: new ( Field :: new ( self . name ( ) , return_type, nullable) ) )
197197 }
198198
199- fn invoke_with_args (
200- & self ,
201- args : ScalarFunctionArgs ,
202- ) -> DataFusionResult < ColumnarValue > {
199+ fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
203200 let target_type = args. return_field . data_type ( ) ;
204201 match target_type {
205202 DataType :: Timestamp ( TimeUnit :: Microsecond , tz) => {
@@ -214,7 +211,7 @@ impl ScalarUDFImpl for SparkCast {
214211fn cast_to_timestamp (
215212 input : & ColumnarValue ,
216213 timezone : Option < Arc < str > > ,
217- ) -> DataFusionResult < ColumnarValue > {
214+ ) -> Result < ColumnarValue > {
218215 match input {
219216 ColumnarValue :: Array ( array) => match array. data_type ( ) {
220217 DataType :: Null => Ok ( ColumnarValue :: Array ( Arc :: new (
0 commit comments