1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ use crate :: conversion_funcs:: boolean:: can_cast_from_boolean;
19+ use crate :: conversion_funcs:: utils:: spark_cast_postprocess;
1820use crate :: utils:: array_with_timezone;
1921use crate :: { timezone, BinaryOutputStyle } ;
2022use crate :: { EvalMode , SparkError , SparkResult } ;
@@ -36,7 +38,7 @@ use arrow::{
3638 GenericStringArray , Int16Array , Int32Array , Int64Array , Int8Array , OffsetSizeTrait ,
3739 PrimitiveArray ,
3840 } ,
39- compute:: { cast_with_options, take, unary , CastOptions } ,
41+ compute:: { cast_with_options, take, CastOptions } ,
4042 datatypes:: {
4143 is_validate_decimal_precision, ArrowPrimitiveType , Decimal128Type , Float32Type ,
4244 Float64Type , Int64Type , TimestampMicrosecondType ,
@@ -47,16 +49,10 @@ use arrow::{
4749} ;
4850use base64:: prelude:: * ;
4951use chrono:: { DateTime , NaiveDate , TimeZone , Timelike } ;
50- use datafusion:: common:: {
51- cast:: as_generic_string_array, internal_err, DataFusionError , Result as DataFusionResult ,
52- ScalarValue ,
53- } ;
52+ use datafusion:: common:: { internal_err, DataFusionError , Result as DataFusionResult , ScalarValue } ;
5453use datafusion:: physical_expr:: PhysicalExpr ;
5554use datafusion:: physical_plan:: ColumnarValue ;
56- use num:: {
57- cast:: AsPrimitive , integer:: div_floor, traits:: CheckedNeg , CheckedSub , Integer , ToPrimitive ,
58- Zero ,
59- } ;
55+ use num:: { cast:: AsPrimitive , traits:: CheckedNeg , CheckedSub , Integer , ToPrimitive , Zero } ;
6056use regex:: Regex ;
6157use std:: str:: FromStr ;
6258use std:: {
@@ -69,8 +65,6 @@ use std::{
6965
7066static TIMESTAMP_FORMAT : Option < & str > = Some ( "%Y-%m-%d %H:%M:%S%.f" ) ;
7167
72- const MICROS_PER_SECOND : i64 = 1000000 ;
73-
7468static CAST_OPTIONS : CastOptions = CastOptions {
7569 safe : true ,
7670 format_options : FormatOptions :: new ( )
@@ -187,7 +181,7 @@ pub fn cast_supported(
187181 }
188182
189183 match ( from_type, to_type) {
190- ( Boolean , _) => can_cast_from_boolean ( to_type, options ) ,
184+ ( Boolean , _) => can_cast_from_boolean ( to_type) ,
191185 ( UInt8 | UInt16 | UInt32 | UInt64 , Int8 | Int16 | Int32 | Int64 )
192186 if options. allow_cast_unsigned_ints =>
193187 {
@@ -302,11 +296,6 @@ fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> b
302296 }
303297}
304298
305- fn can_cast_from_boolean ( to_type : & DataType , _: & SparkCastOptions ) -> bool {
306- use DataType :: * ;
307- matches ! ( to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64 )
308- }
309-
310299fn can_cast_from_byte ( to_type : & DataType , _: & SparkCastOptions ) -> bool {
311300 use DataType :: * ;
312301 matches ! (
@@ -1321,16 +1310,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
13211310 DataType :: Null => {
13221311 matches ! ( to_type, DataType :: List ( _) )
13231312 }
1324- DataType :: Boolean => matches ! (
1325- to_type,
1326- DataType :: Int8
1327- | DataType :: Int16
1328- | DataType :: Int32
1329- | DataType :: Int64
1330- | DataType :: Float32
1331- | DataType :: Float64
1332- | DataType :: Utf8
1333- ) ,
1313+ DataType :: Boolean => can_cast_from_boolean ( to_type) ,
13341314 DataType :: Int8 | DataType :: Int16 | DataType :: Int32 | DataType :: Int64 => {
13351315 matches ! (
13361316 to_type,
@@ -2987,84 +2967,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>>
29872967 }
29882968}
29892969
2990- /// This takes for special casting cases of Spark. E.g., Timestamp to Long.
2991- /// This function runs as a post process of the DataFusion cast(). By the time it arrives here,
2992- /// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify
2993- /// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in
2994- /// expressions/cast.rs, so it can be still Dictionary.
2995- fn spark_cast_postprocess ( array : ArrayRef , from_type : & DataType , to_type : & DataType ) -> ArrayRef {
2996- match ( from_type, to_type) {
2997- ( DataType :: Timestamp ( _, _) , DataType :: Int64 ) => {
2998- // See Spark's `Cast` expression
2999- unary_dyn :: < _ , Int64Type > ( & array, |v| div_floor ( v, MICROS_PER_SECOND ) ) . unwrap ( )
3000- }
3001- ( DataType :: Dictionary ( _, value_type) , DataType :: Int64 )
3002- if matches ! ( value_type. as_ref( ) , & DataType :: Timestamp ( _, _) ) =>
3003- {
3004- // See Spark's `Cast` expression
3005- unary_dyn :: < _ , Int64Type > ( & array, |v| div_floor ( v, MICROS_PER_SECOND ) ) . unwrap ( )
3006- }
3007- ( DataType :: Timestamp ( _, _) , DataType :: Utf8 ) => remove_trailing_zeroes ( array) ,
3008- ( DataType :: Dictionary ( _, value_type) , DataType :: Utf8 )
3009- if matches ! ( value_type. as_ref( ) , & DataType :: Timestamp ( _, _) ) =>
3010- {
3011- remove_trailing_zeroes ( array)
3012- }
3013- _ => array,
3014- }
3015- }
3016-
3017- /// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
3018- fn unary_dyn < F , T > ( array : & ArrayRef , op : F ) -> Result < ArrayRef , ArrowError >
3019- where
3020- T : ArrowPrimitiveType ,
3021- F : Fn ( T :: Native ) -> T :: Native ,
3022- {
3023- if let Some ( d) = array. as_any_dictionary_opt ( ) {
3024- let new_values = unary_dyn :: < F , T > ( d. values ( ) , op) ?;
3025- return Ok ( Arc :: new ( d. with_values ( Arc :: new ( new_values) ) ) ) ;
3026- }
3027-
3028- match array. as_primitive_opt :: < T > ( ) {
3029- Some ( a) if PrimitiveArray :: < T > :: is_compatible ( a. data_type ( ) ) => {
3030- Ok ( Arc :: new ( unary :: < T , F , T > (
3031- array. as_any ( ) . downcast_ref :: < PrimitiveArray < T > > ( ) . unwrap ( ) ,
3032- op,
3033- ) ) )
3034- }
3035- _ => Err ( ArrowError :: NotYetImplemented ( format ! (
3036- "Cannot perform unary operation of type {} on array of type {}" ,
3037- T :: DATA_TYPE ,
3038- array. data_type( )
3039- ) ) ) ,
3040- }
3041- }
3042-
3043- /// Remove any trailing zeroes in the string if they occur after in the fractional seconds,
3044- /// to match Spark behavior
3045- /// example:
3046- /// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9"
3047- /// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99"
3048- /// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999"
3049- /// "1970-01-01 05:30:00" => "1970-01-01 05:30:00"
3050- /// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001"
3051- fn remove_trailing_zeroes ( array : ArrayRef ) -> ArrayRef {
3052- let string_array = as_generic_string_array :: < i32 > ( & array) . unwrap ( ) ;
3053- let result = string_array
3054- . iter ( )
3055- . map ( |s| s. map ( trim_end) )
3056- . collect :: < GenericStringArray < i32 > > ( ) ;
3057- Arc :: new ( result) as ArrayRef
3058- }
3059-
3060- fn trim_end ( s : & str ) -> & str {
3061- if s. rfind ( '.' ) . is_some ( ) {
3062- s. trim_end_matches ( '0' )
3063- } else {
3064- s
3065- }
3066- }
3067-
30682970#[ cfg( test) ]
30692971mod tests {
30702972 use arrow:: array:: StringArray ;
0 commit comments