Skip to content

Commit 70ec5a7

Browse files
author
B Vadlamani
committed
refactor_boolean_cast_ops_add_tests
1 parent b10e40b commit 70ec5a7

3 files changed

Lines changed: 209 additions & 78 deletions

File tree

native/spark-expr/src/conversion_funcs/boolean.rs

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,137 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool {
2121
use DataType::*;
2222
matches!(
2323
to_type,
24-
Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
24+
Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8
2525
)
2626
}
27+
28+
#[cfg(test)]
29+
mod tests {
30+
use super::*;
31+
use crate::cast::cast_array;
32+
use crate::{EvalMode, SparkCastOptions};
33+
use arrow::array::{
34+
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
35+
Int64Array, Int8Array, StringArray,
36+
};
37+
use std::sync::Arc;
38+
39+
fn test_input_bool_array() -> ArrayRef {
40+
Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]))
41+
}
42+
43+
fn test_input_spark_opts() -> SparkCastOptions {
44+
SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false)
45+
}
46+
47+
#[test]
48+
fn test_can_cast_from_boolean() {
49+
assert!(can_cast_from_boolean(&DataType::Boolean));
50+
assert!(can_cast_from_boolean(&DataType::Int8));
51+
assert!(can_cast_from_boolean(&DataType::Int16));
52+
assert!(can_cast_from_boolean(&DataType::Int32));
53+
assert!(can_cast_from_boolean(&DataType::Int64));
54+
assert!(can_cast_from_boolean(&DataType::Float32));
55+
assert!(can_cast_from_boolean(&DataType::Float64));
56+
assert!(can_cast_from_boolean(&DataType::Utf8));
57+
assert!(!can_cast_from_boolean(&DataType::Null));
58+
}
59+
60+
#[test]
61+
fn test_bool_to_int8_cast() {
62+
let result = cast_array(
63+
test_input_bool_array(),
64+
&DataType::Int8,
65+
&test_input_spark_opts(),
66+
)
67+
.unwrap();
68+
let arr = result.as_any().downcast_ref::<Int8Array>().unwrap();
69+
assert_eq!(arr.value(0), 1);
70+
assert_eq!(arr.value(1), 0);
71+
assert!(arr.is_null(2));
72+
}
73+
74+
#[test]
75+
fn test_bool_to_int16_cast() {
76+
let result = cast_array(
77+
test_input_bool_array(),
78+
&DataType::Int16,
79+
&test_input_spark_opts(),
80+
)
81+
.unwrap();
82+
let arr = result.as_any().downcast_ref::<Int16Array>().unwrap();
83+
assert_eq!(arr.value(0), 1);
84+
assert_eq!(arr.value(1), 0);
85+
assert!(arr.is_null(2));
86+
}
87+
88+
#[test]
89+
fn test_bool_to_int32_cast() {
90+
let result = cast_array(
91+
test_input_bool_array(),
92+
&DataType::Int32,
93+
&test_input_spark_opts(),
94+
)
95+
.unwrap();
96+
let arr = result.as_any().downcast_ref::<Int32Array>().unwrap();
97+
assert_eq!(arr.value(0), 1);
98+
assert_eq!(arr.value(1), 0);
99+
assert!(arr.is_null(2));
100+
}
101+
102+
#[test]
103+
fn test_bool_to_int64_cast() {
104+
let result = cast_array(
105+
test_input_bool_array(),
106+
&DataType::Int64,
107+
&test_input_spark_opts(),
108+
)
109+
.unwrap();
110+
let arr = result.as_any().downcast_ref::<Int64Array>().unwrap();
111+
assert_eq!(arr.value(0), 1);
112+
assert_eq!(arr.value(1), 0);
113+
assert!(arr.is_null(2));
114+
}
115+
116+
#[test]
117+
fn test_bool_to_float32_cast() {
118+
let result = cast_array(
119+
test_input_bool_array(),
120+
&DataType::Float32,
121+
&test_input_spark_opts(),
122+
)
123+
.unwrap();
124+
let arr = result.as_any().downcast_ref::<Float32Array>().unwrap();
125+
assert_eq!(arr.value(0), 1.0);
126+
assert_eq!(arr.value(1), 0.0);
127+
assert!(arr.is_null(2));
128+
}
129+
130+
#[test]
131+
fn test_bool_to_float64_cast() {
132+
let result = cast_array(
133+
test_input_bool_array(),
134+
&DataType::Float64,
135+
&test_input_spark_opts(),
136+
)
137+
.unwrap();
138+
let arr = result.as_any().downcast_ref::<Float64Array>().unwrap();
139+
assert_eq!(arr.value(0), 1.0);
140+
assert_eq!(arr.value(1), 0.0);
141+
assert!(arr.is_null(2));
142+
}
143+
144+
#[test]
145+
fn test_bool_to_string_cast() {
146+
let result = cast_array(
147+
test_input_bool_array(),
148+
&DataType::Utf8,
149+
&test_input_spark_opts(),
150+
)
151+
.unwrap();
152+
let arr = result.as_any().downcast_ref::<StringArray>().unwrap();
153+
assert_eq!(arr.value(0), "true");
154+
assert_eq!(arr.value(1), "false");
155+
assert!(arr.is_null(2));
156+
}
157+
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 72 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use crate::conversion_funcs::boolean::can_cast_from_boolean;
19-
use crate::conversion_funcs::utils::spark_cast_postprocess;
19+
use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess};
2020
use crate::utils::array_with_timezone;
2121
use crate::{timezone, BinaryOutputStyle};
2222
use crate::{EvalMode, SparkError, SparkResult};
@@ -176,34 +176,31 @@ pub fn cast_supported(
176176
to_type
177177
};
178178

179-
if from_type == to_type {
180-
return true;
181-
}
182-
183-
match (from_type, to_type) {
184-
(Boolean, _) => can_cast_from_boolean(to_type),
185-
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
186-
if options.allow_cast_unsigned_ints =>
187-
{
188-
true
179+
is_identity_cast(from_type, to_type)
180+
|| match (from_type, to_type) {
181+
(Boolean, _) => can_cast_from_boolean(to_type),
182+
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
183+
if options.allow_cast_unsigned_ints =>
184+
{
185+
true
186+
}
187+
(Int8, _) => can_cast_from_byte(to_type, options),
188+
(Int16, _) => can_cast_from_short(to_type, options),
189+
(Int32, _) => can_cast_from_int(to_type, options),
190+
(Int64, _) => can_cast_from_long(to_type, options),
191+
(Float32, _) => can_cast_from_float(to_type, options),
192+
(Float64, _) => can_cast_from_double(to_type, options),
193+
(Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
194+
(Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
195+
(Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
196+
(Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
197+
(_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
198+
(Struct(from_fields), Struct(to_fields)) => from_fields
199+
.iter()
200+
.zip(to_fields.iter())
201+
.all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
202+
_ => false,
189203
}
190-
(Int8, _) => can_cast_from_byte(to_type, options),
191-
(Int16, _) => can_cast_from_short(to_type, options),
192-
(Int32, _) => can_cast_from_int(to_type, options),
193-
(Int64, _) => can_cast_from_long(to_type, options),
194-
(Float32, _) => can_cast_from_float(to_type, options),
195-
(Float64, _) => can_cast_from_double(to_type, options),
196-
(Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
197-
(Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
198-
(Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
199-
(Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
200-
(_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
201-
(Struct(from_fields), Struct(to_fields)) => from_fields
202-
.iter()
203-
.zip(to_fields.iter())
204-
.all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
205-
_ => false,
206-
}
207204
}
208205

209206
fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
@@ -947,7 +944,7 @@ fn dict_from_values<K: ArrowDictionaryKeyType>(
947944
Ok(Arc::new(dict_array))
948945
}
949946

950-
fn cast_array(
947+
pub fn cast_array(
951948
array: ArrayRef,
952949
to_type: &DataType,
953950
cast_options: &SparkCastOptions,
@@ -1303,16 +1300,26 @@ fn cast_binary_formatter(value: &[u8]) -> String {
13031300
/// Determines if DataFusion supports the given cast in a way that is
13041301
/// compatible with Spark
13051302
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
1306-
if from_type == to_type {
1307-
return true;
1308-
}
1309-
match from_type {
1310-
DataType::Null => {
1311-
matches!(to_type, DataType::List(_))
1312-
}
1313-
DataType::Boolean => can_cast_from_boolean(to_type),
1314-
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1315-
matches!(
1303+
is_identity_cast(from_type, to_type)
1304+
|| match from_type {
1305+
DataType::Null => {
1306+
matches!(to_type, DataType::List(_))
1307+
}
1308+
DataType::Boolean => can_cast_from_boolean(to_type),
1309+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1310+
matches!(
1311+
to_type,
1312+
DataType::Boolean
1313+
| DataType::Int8
1314+
| DataType::Int16
1315+
| DataType::Int32
1316+
| DataType::Int64
1317+
| DataType::Float32
1318+
| DataType::Float64
1319+
| DataType::Utf8
1320+
)
1321+
}
1322+
DataType::Float32 | DataType::Float64 => matches!(
13161323
to_type,
13171324
DataType::Boolean
13181325
| DataType::Int8
@@ -1321,46 +1328,34 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
13211328
| DataType::Int64
13221329
| DataType::Float32
13231330
| DataType::Float64
1324-
| DataType::Utf8
1325-
)
1326-
}
1327-
DataType::Float32 | DataType::Float64 => matches!(
1328-
to_type,
1329-
DataType::Boolean
1330-
| DataType::Int8
1331-
| DataType::Int16
1332-
| DataType::Int32
1333-
| DataType::Int64
1334-
| DataType::Float32
1335-
| DataType::Float64
1336-
),
1337-
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
1338-
to_type,
1339-
DataType::Int8
1340-
| DataType::Int16
1341-
| DataType::Int32
1342-
| DataType::Int64
1343-
| DataType::Float32
1344-
| DataType::Float64
1345-
| DataType::Decimal128(_, _)
1346-
| DataType::Decimal256(_, _)
1347-
| DataType::Utf8 // note that there can be formatting differences
1348-
),
1349-
DataType::Utf8 => matches!(to_type, DataType::Binary),
1350-
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
1351-
DataType::Timestamp(_, _) => {
1352-
matches!(
1331+
),
1332+
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
13531333
to_type,
1354-
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
1355-
)
1356-
}
1357-
DataType::Binary => {
1358-
// note that this is not completely Spark compatible because
1359-
// DataFusion only supports binary data containing valid UTF-8 strings
1360-
matches!(to_type, DataType::Utf8)
1334+
DataType::Int8
1335+
| DataType::Int16
1336+
| DataType::Int32
1337+
| DataType::Int64
1338+
| DataType::Float32
1339+
| DataType::Float64
1340+
| DataType::Decimal128(_, _)
1341+
| DataType::Decimal256(_, _)
1342+
| DataType::Utf8 // note that there can be formatting differences
1343+
),
1344+
DataType::Utf8 => matches!(to_type, DataType::Binary),
1345+
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
1346+
DataType::Timestamp(_, _) => {
1347+
matches!(
1348+
to_type,
1349+
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
1350+
)
1351+
}
1352+
DataType::Binary => {
1353+
// note that this is not completely Spark compatible because
1354+
// DataFusion only supports binary data containing valid UTF-8 strings
1355+
matches!(to_type, DataType::Utf8)
1356+
}
1357+
_ => false,
13611358
}
1362-
_ => false,
1363-
}
13641359
}
13651360

13661361
/// Cast between struct types based on logic in

native/spark-expr/src/conversion_funcs/utils.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,8 @@ fn trim_end(s: &str) -> &str {
107107
s
108108
}
109109
}
110+
111+
#[inline]
112+
pub fn is_identity_cast(from_type: &DataType, to_type: &DataType) -> bool {
113+
from_type == to_type
114+
}

0 commit comments

Comments
 (0)