1616// under the License.
1717
1818use crate :: conversion_funcs:: boolean:: can_cast_from_boolean;
19- use crate :: conversion_funcs:: utils:: spark_cast_postprocess;
19+ use crate :: conversion_funcs:: utils:: { is_identity_cast , spark_cast_postprocess} ;
2020use crate :: utils:: array_with_timezone;
2121use crate :: { timezone, BinaryOutputStyle } ;
2222use crate :: { EvalMode , SparkError , SparkResult } ;
@@ -176,34 +176,31 @@ pub fn cast_supported(
176176 to_type
177177 } ;
178178
179- if from_type == to_type {
180- return true ;
181- }
182-
183- match ( from_type, to_type) {
184- ( Boolean , _) => can_cast_from_boolean ( to_type) ,
185- ( UInt8 | UInt16 | UInt32 | UInt64 , Int8 | Int16 | Int32 | Int64 )
186- if options. allow_cast_unsigned_ints =>
187- {
188- true
179+ is_identity_cast ( from_type, to_type)
180+ || match ( from_type, to_type) {
181+ ( Boolean , _) => can_cast_from_boolean ( to_type) ,
182+ ( UInt8 | UInt16 | UInt32 | UInt64 , Int8 | Int16 | Int32 | Int64 )
183+ if options. allow_cast_unsigned_ints =>
184+ {
185+ true
186+ }
187+ ( Int8 , _) => can_cast_from_byte ( to_type, options) ,
188+ ( Int16 , _) => can_cast_from_short ( to_type, options) ,
189+ ( Int32 , _) => can_cast_from_int ( to_type, options) ,
190+ ( Int64 , _) => can_cast_from_long ( to_type, options) ,
191+ ( Float32 , _) => can_cast_from_float ( to_type, options) ,
192+ ( Float64 , _) => can_cast_from_double ( to_type, options) ,
193+ ( Decimal128 ( p, s) , _) => can_cast_from_decimal ( p, s, to_type, options) ,
194+ ( Timestamp ( _, None ) , _) => can_cast_from_timestamp_ntz ( to_type, options) ,
195+ ( Timestamp ( _, Some ( _) ) , _) => can_cast_from_timestamp ( to_type, options) ,
196+ ( Utf8 | LargeUtf8 , _) => can_cast_from_string ( to_type, options) ,
197+ ( _, Utf8 | LargeUtf8 ) => can_cast_to_string ( from_type, options) ,
198+ ( Struct ( from_fields) , Struct ( to_fields) ) => from_fields
199+ . iter ( )
200+ . zip ( to_fields. iter ( ) )
201+ . all ( |( a, b) | cast_supported ( a. data_type ( ) , b. data_type ( ) , options) ) ,
202+ _ => false ,
189203 }
190- ( Int8 , _) => can_cast_from_byte ( to_type, options) ,
191- ( Int16 , _) => can_cast_from_short ( to_type, options) ,
192- ( Int32 , _) => can_cast_from_int ( to_type, options) ,
193- ( Int64 , _) => can_cast_from_long ( to_type, options) ,
194- ( Float32 , _) => can_cast_from_float ( to_type, options) ,
195- ( Float64 , _) => can_cast_from_double ( to_type, options) ,
196- ( Decimal128 ( p, s) , _) => can_cast_from_decimal ( p, s, to_type, options) ,
197- ( Timestamp ( _, None ) , _) => can_cast_from_timestamp_ntz ( to_type, options) ,
198- ( Timestamp ( _, Some ( _) ) , _) => can_cast_from_timestamp ( to_type, options) ,
199- ( Utf8 | LargeUtf8 , _) => can_cast_from_string ( to_type, options) ,
200- ( _, Utf8 | LargeUtf8 ) => can_cast_to_string ( from_type, options) ,
201- ( Struct ( from_fields) , Struct ( to_fields) ) => from_fields
202- . iter ( )
203- . zip ( to_fields. iter ( ) )
204- . all ( |( a, b) | cast_supported ( a. data_type ( ) , b. data_type ( ) , options) ) ,
205- _ => false ,
206- }
207204}
208205
209206fn can_cast_from_string ( to_type : & DataType , options : & SparkCastOptions ) -> bool {
@@ -947,7 +944,7 @@ fn dict_from_values<K: ArrowDictionaryKeyType>(
947944 Ok ( Arc :: new ( dict_array) )
948945}
949946
950- fn cast_array (
947+ pub fn cast_array (
951948 array : ArrayRef ,
952949 to_type : & DataType ,
953950 cast_options : & SparkCastOptions ,
@@ -1303,16 +1300,26 @@ fn cast_binary_formatter(value: &[u8]) -> String {
13031300/// Determines if DataFusion supports the given cast in a way that is
13041301/// compatible with Spark
13051302fn is_datafusion_spark_compatible ( from_type : & DataType , to_type : & DataType ) -> bool {
1306- if from_type == to_type {
1307- return true ;
1308- }
1309- match from_type {
1310- DataType :: Null => {
1311- matches ! ( to_type, DataType :: List ( _) )
1312- }
1313- DataType :: Boolean => can_cast_from_boolean ( to_type) ,
1314- DataType :: Int8 | DataType :: Int16 | DataType :: Int32 | DataType :: Int64 => {
1315- matches ! (
1303+ is_identity_cast ( from_type, to_type)
1304+ || match from_type {
1305+ DataType :: Null => {
1306+ matches ! ( to_type, DataType :: List ( _) )
1307+ }
1308+ DataType :: Boolean => can_cast_from_boolean ( to_type) ,
1309+ DataType :: Int8 | DataType :: Int16 | DataType :: Int32 | DataType :: Int64 => {
1310+ matches ! (
1311+ to_type,
1312+ DataType :: Boolean
1313+ | DataType :: Int8
1314+ | DataType :: Int16
1315+ | DataType :: Int32
1316+ | DataType :: Int64
1317+ | DataType :: Float32
1318+ | DataType :: Float64
1319+ | DataType :: Utf8
1320+ )
1321+ }
1322+ DataType :: Float32 | DataType :: Float64 => matches ! (
13161323 to_type,
13171324 DataType :: Boolean
13181325 | DataType :: Int8
@@ -1321,46 +1328,34 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
13211328 | DataType :: Int64
13221329 | DataType :: Float32
13231330 | DataType :: Float64
1324- | DataType :: Utf8
1325- )
1326- }
1327- DataType :: Float32 | DataType :: Float64 => matches ! (
1328- to_type,
1329- DataType :: Boolean
1330- | DataType :: Int8
1331- | DataType :: Int16
1332- | DataType :: Int32
1333- | DataType :: Int64
1334- | DataType :: Float32
1335- | DataType :: Float64
1336- ) ,
1337- DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) => matches ! (
1338- to_type,
1339- DataType :: Int8
1340- | DataType :: Int16
1341- | DataType :: Int32
1342- | DataType :: Int64
1343- | DataType :: Float32
1344- | DataType :: Float64
1345- | DataType :: Decimal128 ( _, _)
1346- | DataType :: Decimal256 ( _, _)
1347- | DataType :: Utf8 // note that there can be formatting differences
1348- ) ,
1349- DataType :: Utf8 => matches ! ( to_type, DataType :: Binary ) ,
1350- DataType :: Date32 => matches ! ( to_type, DataType :: Int32 | DataType :: Utf8 ) ,
1351- DataType :: Timestamp ( _, _) => {
1352- matches ! (
1331+ ) ,
1332+ DataType :: Decimal128 ( _, _) | DataType :: Decimal256 ( _, _) => matches ! (
13531333 to_type,
1354- DataType :: Int64 | DataType :: Date32 | DataType :: Utf8 | DataType :: Timestamp ( _, _)
1355- )
1356- }
1357- DataType :: Binary => {
1358- // note that this is not completely Spark compatible because
1359- // DataFusion only supports binary data containing valid UTF-8 strings
1360- matches ! ( to_type, DataType :: Utf8 )
1334+ DataType :: Int8
1335+ | DataType :: Int16
1336+ | DataType :: Int32
1337+ | DataType :: Int64
1338+ | DataType :: Float32
1339+ | DataType :: Float64
1340+ | DataType :: Decimal128 ( _, _)
1341+ | DataType :: Decimal256 ( _, _)
1342+ | DataType :: Utf8 // note that there can be formatting differences
1343+ ) ,
1344+ DataType :: Utf8 => matches ! ( to_type, DataType :: Binary ) ,
1345+ DataType :: Date32 => matches ! ( to_type, DataType :: Int32 | DataType :: Utf8 ) ,
1346+ DataType :: Timestamp ( _, _) => {
1347+ matches ! (
1348+ to_type,
1349+ DataType :: Int64 | DataType :: Date32 | DataType :: Utf8 | DataType :: Timestamp ( _, _)
1350+ )
1351+ }
1352+ DataType :: Binary => {
1353+ // note that this is not completely Spark compatible because
1354+ // DataFusion only supports binary data containing valid UTF-8 strings
1355+ matches ! ( to_type, DataType :: Utf8 )
1356+ }
1357+ _ => false ,
13611358 }
1362- _ => false ,
1363- }
13641359}
13651360
13661361/// Cast between struct types based on logic in
0 commit comments