Skip to content

Commit 05ece88

Browse files
committed
refactor_bool_cast_module
1 parent 668f7a3 commit 05ece88

3 files changed

Lines changed: 51 additions & 48 deletions

File tree

native/spark-expr/benches/cast_from_boolean.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ fn criterion_benchmark(c: &mut Criterion) {
2727
let expr = Arc::new(Column::new("a", 0));
2828
let boolean_batch = create_boolean_batch();
2929
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
30-
Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
3130
let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
3231
let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
3332
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 50 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ use std::{
6969

7070
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
7171

72+
pub(crate) const MICROS_PER_SECOND: i64 = 1000000;
73+
7274
static CAST_OPTIONS: CastOptions = CastOptions {
7375
safe: true,
7476
format_options: FormatOptions::new()
@@ -1166,26 +1168,16 @@ fn cast_binary_formatter(value: &[u8]) -> String {
11661168
/// Determines if DataFusion supports the given cast in a way that is
11671169
/// compatible with Spark
11681170
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
1169-
from_type == to_type
1170-
|| match from_type {
1171-
DataType::Null => {
1172-
matches!(to_type, DataType::List(_))
1173-
}
1174-
DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type),
1175-
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1176-
matches!(
1177-
to_type,
1178-
DataType::Boolean
1179-
| DataType::Int8
1180-
| DataType::Int16
1181-
| DataType::Int32
1182-
| DataType::Int64
1183-
| DataType::Float32
1184-
| DataType::Float64
1185-
| DataType::Utf8
1186-
)
1187-
}
1188-
DataType::Float32 | DataType::Float64 => matches!(
1171+
if from_type == to_type {
1172+
return true;
1173+
}
1174+
match from_type {
1175+
DataType::Null => {
1176+
matches!(to_type, DataType::List(_))
1177+
}
1178+
DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type),
1179+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
1180+
matches!(
11891181
to_type,
11901182
DataType::Boolean
11911183
| DataType::Int8
@@ -1194,34 +1186,46 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
11941186
| DataType::Int64
11951187
| DataType::Float32
11961188
| DataType::Float64
1197-
),
1198-
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
1189+
| DataType::Utf8
1190+
)
1191+
}
1192+
DataType::Float32 | DataType::Float64 => matches!(
1193+
to_type,
1194+
DataType::Boolean
1195+
| DataType::Int8
1196+
| DataType::Int16
1197+
| DataType::Int32
1198+
| DataType::Int64
1199+
| DataType::Float32
1200+
| DataType::Float64
1201+
),
1202+
DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!(
1203+
to_type,
1204+
DataType::Int8
1205+
| DataType::Int16
1206+
| DataType::Int32
1207+
| DataType::Int64
1208+
| DataType::Float32
1209+
| DataType::Float64
1210+
| DataType::Decimal128(_, _)
1211+
| DataType::Decimal256(_, _)
1212+
| DataType::Utf8 // note that there can be formatting differences
1213+
),
1214+
DataType::Utf8 => matches!(to_type, DataType::Binary),
1215+
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
1216+
DataType::Timestamp(_, _) => {
1217+
matches!(
11991218
to_type,
1200-
DataType::Int8
1201-
| DataType::Int16
1202-
| DataType::Int32
1203-
| DataType::Int64
1204-
| DataType::Float32
1205-
| DataType::Float64
1206-
| DataType::Decimal128(_, _)
1207-
| DataType::Decimal256(_, _)
1208-
| DataType::Utf8 // note that there can be formatting differences
1209-
),
1210-
DataType::Utf8 => matches!(to_type, DataType::Binary),
1211-
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
1212-
DataType::Timestamp(_, _) => {
1213-
matches!(
1214-
to_type,
1215-
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
1216-
)
1217-
}
1218-
DataType::Binary => {
1219-
// note that this is not completely Spark compatible because
1220-
// DataFusion only supports binary data containing valid UTF-8 strings
1221-
matches!(to_type, DataType::Utf8)
1222-
}
1223-
_ => false,
1219+
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
1220+
)
1221+
}
1222+
DataType::Binary => {
1223+
// note that this is not completely Spark compatible because
1224+
// DataFusion only supports binary data containing valid UTF-8 strings
1225+
matches!(to_type, DataType::Utf8)
12241226
}
1227+
_ => false,
1228+
}
12251229
}
12261230

12271231
/// Cast between struct types based on logic in

native/spark-expr/src/conversion_funcs/utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::cast::MICROS_PER_SECOND;
1819
use crate::SparkError;
1920
use arrow::array::{
2021
Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray,
@@ -26,7 +27,6 @@ use datafusion::common::cast::as_generic_string_array;
2627
use num::integer::div_floor;
2728
use std::sync::Arc;
2829

29-
const MICROS_PER_SECOND: i64 = 1000000;
3030
/// A forked and modified version of Arrow's `unary_dyn`, which is being deprecated
3131
pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
3232
where

0 commit comments

Comments
 (0)