From af692956ba61d33753515370d687e6ea74e010ff Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 10 Feb 2026 00:19:44 -0800 Subject: [PATCH 01/11] refactor_boolean_cast_ops --- .../src/conversion_funcs/boolean.rs | 26 +++++ .../spark-expr/src/conversion_funcs/cast.rs | 106 +---------------- native/spark-expr/src/conversion_funcs/mod.rs | 2 + .../spark-expr/src/conversion_funcs/utils.rs | 109 ++++++++++++++++++ 4 files changed, 143 insertions(+), 100 deletions(-) create mode 100644 native/spark-expr/src/conversion_funcs/boolean.rs create mode 100644 native/spark-expr/src/conversion_funcs/utils.rs diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs new file mode 100644 index 0000000000..f026af47e1 --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; + +pub fn can_cast_from_boolean(to_type: &DataType) -> bool { + use DataType::*; + matches!( + to_type, + Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 + ) +} diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 2809104f26..c7e446ea1e 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::conversion_funcs::boolean::can_cast_from_boolean; +use crate::conversion_funcs::utils::spark_cast_postprocess; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; use crate::{timezone, BinaryOutputStyle}; @@ -37,7 +39,7 @@ use arrow::{ GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }, - compute::{cast_with_options, take, unary, CastOptions}, + compute::{cast_with_options, take, CastOptions}, datatypes::{ is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type, Float64Type, Int64Type, TimestampMicrosecondType, @@ -48,16 +50,10 @@ use arrow::{ }; use base64::prelude::*; use chrono::{DateTime, NaiveDate, TimeZone, Timelike}; -use datafusion::common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult, - ScalarValue, -}; +use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; -use num::{ - cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, - Zero, -}; +use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, ToPrimitive, Zero}; use regex::Regex; use std::str::FromStr; use std::{ @@ -70,8 +66,6 @@ use std::{ static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); -const MICROS_PER_SECOND: i64 = 1000000; - static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -162,7 +156,6 @@ impl Hash for Cast { self.cast_options.hash(state); } } - macro_rules! cast_utf8_to_int { ($array:expr, $array_type:ty, $parse_fn:expr) => {{ let len = $array.len(); @@ -1145,16 +1138,7 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b DataType::Null => { matches!(to_type, DataType::List(_)) } - DataType::Boolean => matches!( - to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - ), + DataType::Boolean => can_cast_from_boolean(to_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { matches!( to_type, @@ -2811,84 +2795,6 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> } } -/// This takes for special casting cases of Spark. E.g., Timestamp to Long. -/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, -/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify -/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in -/// expressions/cast.rs, so it can be still Dictionary. -fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { - match (from_type, to_type) { - (DataType::Timestamp(_, _), DataType::Int64) => { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Dictionary(_, value_type), DataType::Int64) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - // See Spark's `Cast` expression - unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() - } - (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), - (DataType::Dictionary(_, value_type), DataType::Utf8) - if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => - { - remove_trailing_zeroes(array) - } - _ => array, - } -} - -/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated -fn unary_dyn(array: &ArrayRef, op: F) -> Result -where - T: ArrowPrimitiveType, - F: Fn(T::Native) -> T::Native, -{ - if let Some(d) = array.as_any_dictionary_opt() { - let new_values = unary_dyn::(d.values(), op)?; - return Ok(Arc::new(d.with_values(Arc::new(new_values)))); - } - - match array.as_primitive_opt::() { - Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { - Ok(Arc::new(unary::( - array.as_any().downcast_ref::>().unwrap(), - op, - ))) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot perform unary operation of type {} on array of type {}", - T::DATA_TYPE, - array.data_type() - ))), - } -} - -/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, -/// to match Spark behavior -/// example: -/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" -/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" -/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" -/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" -/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" -fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { - let string_array = as_generic_string_array::(&array).unwrap(); - let result = string_array - .iter() - .map(|s| s.map(trim_end)) - .collect::>(); - Arc::new(result) as ArrayRef -} - -fn trim_end(s: &str) -> &str { - if s.rfind('.').is_some() { - s.trim_end_matches('0') - } else { - s - } -} - #[cfg(test)] mod tests { use arrow::array::StringArray; diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs index f2c6f7ca36..190c115204 100644 --- a/native/spark-expr/src/conversion_funcs/mod.rs +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -15,4 +15,6 @@ // specific language governing permissions and limitations // under the License. +mod boolean; pub mod cast; +mod utils; diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs new file mode 100644 index 0000000000..4abb45f725 --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, +}; +use arrow::compute::unary; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::error::ArrowError; +use datafusion::common::cast::as_generic_string_array; +use num::integer::div_floor; +use std::sync::Arc; + +const MICROS_PER_SECOND: i64 = 1000000; +/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated +pub fn unary_dyn(array: &ArrayRef, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + if let Some(d) = array.as_any_dictionary_opt() { + let new_values = unary_dyn::(d.values(), op)?; + return Ok(Arc::new(d.with_values(Arc::new(new_values)))); + } + + match array.as_primitive_opt::() { + Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + array.data_type() + ))), + } +} + +/// This takes for special casting cases of Spark. E.g., Timestamp to Long. +/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, +/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify +/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in +/// expressions/cast.rs, so it can be still Dictionary. +pub fn spark_cast_postprocess( + array: ArrayRef, + from_type: &DataType, + to_type: &DataType, +) -> ArrayRef { + match (from_type, to_type) { + (DataType::Timestamp(_, _), DataType::Int64) => { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Dictionary(_, value_type), DataType::Int64) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), + (DataType::Dictionary(_, value_type), DataType::Utf8) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + remove_trailing_zeroes(array) + } + _ => array, + } +} + +/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, +/// to match Spark behavior +/// example: +/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" +/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" +/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" +/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" +/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" +fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { + let string_array = as_generic_string_array::(&array).unwrap(); + let result = string_array + .iter() + .map(|s| s.map(trim_end)) + .collect::>(); + Arc::new(result) as ArrayRef +} + +fn trim_end(s: &str) -> &str { + if s.rfind('.').is_some() { + s.trim_end_matches('0') + } else { + s + } +} From a488155ecd6b6f57ea4aa82cbad98add01ce8ae1 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 10 Feb 2026 11:33:39 -0800 Subject: [PATCH 02/11] refactor_boolean_cast_ops_add_tests --- .../src/conversion_funcs/boolean.rs | 133 +++++++++++++++++- .../spark-expr/src/conversion_funcs/cast.rs | 99 +++++++------ .../spark-expr/src/conversion_funcs/utils.rs | 5 + 3 files changed, 186 insertions(+), 51 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index f026af47e1..b46d8667b1 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -21,6 +21,137 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool { use DataType::*; matches!( to_type, - Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 ) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cast::cast_array; + use crate::{EvalMode, SparkCastOptions}; + use arrow::array::{ + Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, StringArray, + }; + use std::sync::Arc; + + fn test_input_bool_array() -> ArrayRef { + Arc::new(BooleanArray::from(vec![Some(true), Some(false), None])) + } + + fn test_input_spark_opts() -> SparkCastOptions { + SparkCastOptions::new(EvalMode::Legacy, "Asia/Kolkata", false) + } + + #[test] + fn test_can_cast_from_boolean() { + assert!(can_cast_from_boolean(&DataType::Boolean)); + assert!(can_cast_from_boolean(&DataType::Int8)); + assert!(can_cast_from_boolean(&DataType::Int16)); + assert!(can_cast_from_boolean(&DataType::Int32)); + assert!(can_cast_from_boolean(&DataType::Int64)); + assert!(can_cast_from_boolean(&DataType::Float32)); + assert!(can_cast_from_boolean(&DataType::Float64)); + assert!(can_cast_from_boolean(&DataType::Utf8)); + assert!(!can_cast_from_boolean(&DataType::Null)); + } + + #[test] + fn test_bool_to_int8_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Int8, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_int16_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Int16, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_int32_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Int32, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_int64_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Int64, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1); + assert_eq!(arr.value(1), 0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_float32_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Float32, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1.0); + assert_eq!(arr.value(1), 0.0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_float64_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Float64, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), 1.0); + assert_eq!(arr.value(1), 0.0); + assert!(arr.is_null(2)); + } + + #[test] + fn test_bool_to_string_cast() { + let result = cast_array( + test_input_bool_array(), + &DataType::Utf8, + &test_input_spark_opts(), + ) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "true"); + assert_eq!(arr.value(1), "false"); + assert!(arr.is_null(2)); + } +} diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index c7e446ea1e..fc8b31b1a7 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -16,7 +16,7 @@ // under the License. use crate::conversion_funcs::boolean::can_cast_from_boolean; -use crate::conversion_funcs::utils::spark_cast_postprocess; +use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess}; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; use crate::{timezone, BinaryOutputStyle}; @@ -156,6 +156,7 @@ impl Hash for Cast { self.cast_options.hash(state); } } + macro_rules! cast_utf8_to_int { ($array:expr, $array_type:ty, $parse_fn:expr) => {{ let len = $array.len(); @@ -752,7 +753,7 @@ fn dict_from_values( Ok(Arc::new(dict_array)) } -fn cast_array( +pub fn cast_array( array: ArrayRef, to_type: &DataType, cast_options: &SparkCastOptions, @@ -1131,16 +1132,26 @@ fn cast_binary_formatter(value: &[u8]) -> String { /// Determines if DataFusion supports the given cast in a way that is /// compatible with Spark fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { - if from_type == to_type { - return true; - } - match from_type { - DataType::Null => { - matches!(to_type, DataType::List(_)) - } - DataType::Boolean => can_cast_from_boolean(to_type), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - matches!( + is_identity_cast(from_type, to_type) + || match from_type { + DataType::Null => { + matches!(to_type, DataType::List(_)) + } + DataType::Boolean => can_cast_from_boolean(to_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( to_type, DataType::Boolean | DataType::Int8 @@ -1149,46 +1160,34 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b | DataType::Int64 | DataType::Float32 | DataType::Float64 - | DataType::Utf8 - ) - } - DataType::Float32 | DataType::Float64 => matches!( - to_type, - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ), - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( - to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - | DataType::Utf8 // note that there can be formatting differences - ), - DataType::Utf8 => matches!(to_type, DataType::Binary), - DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), - DataType::Timestamp(_, _) => { - matches!( + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( to_type, - DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) - ) - } - DataType::Binary => { - // note that this is not completely Spark compatible because - // DataFusion only supports binary data containing valid UTF-8 strings - matches!(to_type, DataType::Utf8) + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Utf8 // note that there can be formatting differences + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, } - _ => false, - } } /// Cast between struct types based on logic in diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs index 4abb45f725..415db2778d 100644 --- a/native/spark-expr/src/conversion_funcs/utils.rs +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -107,3 +107,8 @@ fn trim_end(s: &str) -> &str { s } } + +#[inline] +pub fn is_identity_cast(from_type: &DataType, to_type: &DataType) -> bool { + from_type == to_type +} From 6f9eb8f7fce6c5953e2c5ab3f88734fa29c703a5 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 10 Feb 2026 15:56:22 -0800 Subject: [PATCH 03/11] refactor_boolean_cast_ops_add_tests --- .../spark-expr/src/conversion_funcs/cast.rs | 19 +------------------ .../spark-expr/src/conversion_funcs/utils.rs | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index fc8b31b1a7..caf5eb8296 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -16,6 +16,7 @@ // under the License. use crate::conversion_funcs::boolean::can_cast_from_boolean; +use crate::conversion_funcs::utils::{cast_overflow, invalid_value}; use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess}; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; @@ -2379,24 +2380,6 @@ fn parse_decimal_str( Ok((final_mantissa, final_scale)) } -#[inline] -fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { - SparkError::CastInvalidValue { - value: value.to_string(), - from_type: from_type.to_string(), - to_type: to_type.to_string(), - } -} - -#[inline] -fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { - SparkError::CastOverFlow { - value: value.to_string(), - from_type: from_type.to_string(), - to_type: to_type.to_string(), - } -} - impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs index 415db2778d..3a75792c71 100644 --- a/native/spark-expr/src/conversion_funcs/utils.rs +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::SparkError; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, }; @@ -112,3 +113,21 @@ fn trim_end(s: &str) -> &str { pub fn is_identity_cast(from_type: &DataType, to_type: &DataType) -> bool { from_type == to_type } + +#[inline] +pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastOverFlow { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + +#[inline] +pub fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} From 06f662002c2b17bd0eafb554fad1eecd130d8f13 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 10 Feb 2026 16:53:11 -0800 Subject: [PATCH 04/11] refactor_boolean_cast_ops_add_benchmarks --- native/spark-expr/Cargo.toml | 4 ++ native/spark-expr/benches/cast_boolean.rs | 74 +++++++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 native/spark-expr/benches/cast_boolean.rs diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index fd0a211b29..46489692a2 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -95,3 +95,7 @@ harness = false [[test]] name = "test_udf_registration" path = "tests/spark_expr_reg.rs" + +[[bench]] +name = "cast_boolean" +harness = false diff --git a/native/spark-expr/benches/cast_boolean.rs b/native/spark-expr/benches/cast_boolean.rs new file mode 100644 index 0000000000..90fb42c685 --- /dev/null +++ b/native/spark-expr/benches/cast_boolean.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use arrow::array::{BooleanBuilder, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::physical_expr::expressions::Column; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; + +fn criterion_benchmark(c: &mut Criterion) { + let expr = Arc::new(Column::new("a", 0)); + let boolean_batch = create_boolean_batch(); + let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); + Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_to_i8 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); + let cast_to_i16 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options); + + let mut group = c.benchmark_group(format!("cast_bool_to_int")); + group.bench_function("i8", |b| { + b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap()); + }); + group.bench_function("i16", |b| { + b.iter(|| cast_to_i16.evaluate(&boolean_batch).unwrap()); + }); + group.bench_function("i32", |b| { + b.iter(|| cast_to_i32.evaluate(&boolean_batch).unwrap()); + }); + group.bench_function("i64", |b| { + b.iter(|| cast_to_i64.evaluate(&boolean_batch).unwrap()); + }); +} + +fn create_boolean_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])); + let mut b = BooleanBuilder::with_capacity(1000); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + let array = b.finish(); + RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); \ No newline at end of file From fa1bb14e841970620d95ec2be7ba2cf9206c5c4f Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 11 Feb 2026 07:58:42 -0800 Subject: [PATCH 05/11] refactor_boolean_cast_ops_add_benchmarks --- native/spark-expr/benches/cast_boolean.rs | 24 +++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/native/spark-expr/benches/cast_boolean.rs b/native/spark-expr/benches/cast_boolean.rs index 90fb42c685..1378e48975 100644 --- a/native/spark-expr/benches/cast_boolean.rs +++ b/native/spark-expr/benches/cast_boolean.rs @@ -15,23 +15,26 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; use arrow::array::{BooleanBuilder, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::physical_expr::expressions::Column; use datafusion::physical_expr::PhysicalExpr; use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; +use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { let expr = Arc::new(Column::new("a", 0)); let boolean_batch = create_boolean_batch(); let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); - let cast_to_i8 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); - let cast_to_i16 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); - let cast_to_i32 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options.clone()); - let cast_to_i64 = Cast::new(expr.clone(), DataType::Boolean, spark_cast_options); + let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); + let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone()); + let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone()); + let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone()); + let cast_to_str = Cast::new(expr, DataType::Utf8, spark_cast_options); let mut group = c.benchmark_group(format!("cast_bool_to_int")); group.bench_function("i8", |b| { @@ -46,6 +49,15 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("i64", |b| { b.iter(|| cast_to_i64.evaluate(&boolean_batch).unwrap()); }); + group.bench_function("f32", |b| { + b.iter(|| cast_to_f32.evaluate(&boolean_batch).unwrap()); + }); + group.bench_function("f64", |b| { + b.iter(|| cast_to_f64.evaluate(&boolean_batch).unwrap()); + }); + group.bench_function("str", |b| { + b.iter(|| cast_to_str.evaluate(&boolean_batch).unwrap()); + }); } fn create_boolean_batch() -> RecordBatch { @@ -71,4 +83,4 @@ criterion_group! { config = config(); targets = criterion_benchmark } -criterion_main!(benches); \ No newline at end of file +criterion_main!(benches); From e1a4f5dc4c8a4e1ea967f566dd18ed0035cb9e80 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 11 Feb 2026 08:39:53 -0800 Subject: [PATCH 06/11] refactor_boolean_cast_ops_add_benchmarks --- native/spark-expr/benches/cast_boolean.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/benches/cast_boolean.rs b/native/spark-expr/benches/cast_boolean.rs index 1378e48975..03a2c51bdb 100644 --- a/native/spark-expr/benches/cast_boolean.rs +++ b/native/spark-expr/benches/cast_boolean.rs @@ -36,7 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone()); let cast_to_str = Cast::new(expr, DataType::Utf8, spark_cast_options); - let mut group = c.benchmark_group(format!("cast_bool_to_int")); + let mut group = c.benchmark_group("cast_bool_to_int".to_string()); group.bench_function("i8", |b| { b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap()); }); From d1872ab84990644ad941c7b6cc26ab696afccc07 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 11 Feb 2026 16:45:29 -0800 Subject: [PATCH 07/11] refactor_boolean_cast_ops_add_benchmarks_rebase_main --- native/spark-expr/Cargo.toml | 2 +- .../{cast_boolean.rs => cast_from_boolean.rs} | 8 +++- .../src/conversion_funcs/boolean.rs | 37 +++++++++++++++++++ .../spark-expr/src/conversion_funcs/cast.rs | 12 +----- 4 files changed, 45 insertions(+), 14 deletions(-) rename native/spark-expr/benches/{cast_boolean.rs => cast_from_boolean.rs} (90%) diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 46489692a2..bcfedb15d1 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -97,5 +97,5 @@ name = "test_udf_registration" path = "tests/spark_expr_reg.rs" [[bench]] -name = "cast_boolean" +name = "cast_from_boolean" harness = false diff --git a/native/spark-expr/benches/cast_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs similarity index 90% rename from native/spark-expr/benches/cast_boolean.rs rename to native/spark-expr/benches/cast_from_boolean.rs index 03a2c51bdb..db8e77f61b 100644 --- a/native/spark-expr/benches/cast_boolean.rs +++ b/native/spark-expr/benches/cast_from_boolean.rs @@ -34,9 +34,10 @@ fn criterion_benchmark(c: &mut Criterion) { let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone()); let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone()); let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone()); - let cast_to_str = Cast::new(expr, DataType::Utf8, spark_cast_options); + let cast_to_str = Cast::new(expr.clone(), DataType::Utf8, spark_cast_options.clone()); + let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4), spark_cast_options); - let mut group = c.benchmark_group("cast_bool_to_int".to_string()); + let mut group = c.benchmark_group("cast_bool".to_string()); group.bench_function("i8", |b| { b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap()); }); @@ -58,6 +59,9 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function("str", |b| { b.iter(|| cast_to_str.evaluate(&boolean_batch).unwrap()); }); + group.bench_function("decimal", |b| { + b.iter(|| cast_to_decimal.evaluate(&boolean_batch).unwrap()); + }); } fn create_boolean_batch() -> RecordBatch { diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index b46d8667b1..42c5d2bb93 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::SparkResult; +use arrow::array::{ArrayRef, AsArray, Decimal128Array}; use arrow::datatypes::DataType; +use std::sync::Arc; pub fn can_cast_from_boolean(to_type: &DataType) -> bool { use DataType::*; @@ -25,6 +28,21 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool { ) } +// only incompatible boolean cast +pub fn cast_boolean_to_decimal( + array: &ArrayRef, + precision: u8, + scale: i8, +) -> SparkResult { + let bool_array = array.as_boolean(); + let scaled_val = 10_i128.pow(scale as u32); + let result: Decimal128Array = bool_array + .iter() + .map(|v| v.map(|b| if b { scaled_val } else { 0 })) + .collect(); + Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) +} + #[cfg(test)] mod tests { use super::*; @@ -34,6 +52,7 @@ mod tests { Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, }; + use arrow::datatypes::DataType::Decimal128; use std::sync::Arc; fn test_input_bool_array() -> ArrayRef { @@ -54,6 +73,7 @@ mod tests { assert!(can_cast_from_boolean(&DataType::Float32)); assert!(can_cast_from_boolean(&DataType::Float64)); assert!(can_cast_from_boolean(&DataType::Utf8)); + assert!(can_cast_from_boolean(&DataType::Decimal128(10, 4))); assert!(!can_cast_from_boolean(&DataType::Null)); } @@ -154,4 +174,21 @@ mod tests { assert_eq!(arr.value(1), "false"); assert!(arr.is_null(2)); } + + #[test] + fn test_bool_to_decimal_cast() { + let result = cast_array( + test_input_bool_array(), + &Decimal128(10, 4), + &test_input_spark_opts(), + ) + .unwrap(); + let expected_arr = Decimal128Array::from(vec![10000_i128, 0_i128]) + .with_precision_and_scale(10, 4) + .unwrap(); + let arr = result.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), expected_arr.value(0)); + assert_eq!(arr.value(1), expected_arr.value(1)); + assert!(arr.is_null(2)); + } } diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index caf5eb8296..8a2df093bd 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::conversion_funcs::boolean::can_cast_from_boolean; +use crate::conversion_funcs::boolean::{can_cast_from_boolean, cast_boolean_to_decimal}; use crate::conversion_funcs::utils::{cast_overflow, invalid_value}; use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess}; use crate::utils::array_with_timezone; @@ -972,16 +972,6 @@ fn cast_date_to_timestamp( )) } -fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult { - let bool_array = array.as_boolean(); - let scaled_val = 10_i128.pow(scale as u32); - let result: Decimal128Array = bool_array - .iter() - .map(|v| v.map(|b| if b { scaled_val } else { 0 })) - .collect(); - Ok(Arc::new(result.with_precision_and_scale(precision, scale)?)) -} - fn cast_string_to_float( array: &ArrayRef, to_type: &DataType, From b9d75d5d38445667ea5dbdc24276cf1451fcab92 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 11 Feb 2026 22:35:56 -0800 Subject: [PATCH 08/11] refactor_boolean_cast_ops_add_benchmarks_rebase_main --- native/spark-expr/src/conversion_funcs/boolean.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index 42c5d2bb93..c36d5299f4 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -24,7 +24,7 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool { use DataType::*; matches!( to_type, - Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Decimal128(_, _) ) } From aac9fc6c6edf10e115fc60ebed684caad1e6c3e7 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Sat, 14 Feb 2026 18:13:39 -0800 Subject: [PATCH 09/11] refactor_boolean_cast_ops_add_benchmarks_rebase_main --- native/spark-expr/src/conversion_funcs/boolean.rs | 2 +- native/spark-expr/src/conversion_funcs/cast.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index c36d5299f4..0c2341ac66 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -28,7 +28,7 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool { ) } -// only incompatible boolean cast +// only DF incompatible boolean cast pub fn cast_boolean_to_decimal( array: &ArrayRef, precision: u8, diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 8a2df093bd..a1e8754cb5 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -754,7 +754,7 @@ fn dict_from_values( Ok(Arc::new(dict_array)) } -pub fn cast_array( +pub(crate) fn cast_array( array: ArrayRef, to_type: &DataType, cast_options: &SparkCastOptions, From 12fcedc3c1a0b1066b9197405319a2bd1c9fda71 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 18 Feb 2026 13:57:33 -0800 Subject: [PATCH 10/11] refactor_bool_cast_module --- .../src/conversion_funcs/boolean.rs | 28 ++++++++++--------- .../spark-expr/src/conversion_funcs/cast.rs | 10 ++++--- .../spark-expr/src/conversion_funcs/utils.rs | 5 ---- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs b/native/spark-expr/src/conversion_funcs/boolean.rs index 0c2341ac66..db288fa32a 100644 --- a/native/spark-expr/src/conversion_funcs/boolean.rs +++ b/native/spark-expr/src/conversion_funcs/boolean.rs @@ -20,11 +20,11 @@ use arrow::array::{ArrayRef, AsArray, Decimal128Array}; use arrow::datatypes::DataType; use std::sync::Arc; -pub fn can_cast_from_boolean(to_type: &DataType) -> bool { +pub fn is_df_cast_from_bool_spark_compatible(to_type: &DataType) -> bool { use DataType::*; matches!( to_type, - Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 | Decimal128(_, _) + Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 ) } @@ -64,17 +64,19 @@ mod tests { } #[test] - fn test_can_cast_from_boolean() { - assert!(can_cast_from_boolean(&DataType::Boolean)); - assert!(can_cast_from_boolean(&DataType::Int8)); - assert!(can_cast_from_boolean(&DataType::Int16)); - assert!(can_cast_from_boolean(&DataType::Int32)); - assert!(can_cast_from_boolean(&DataType::Int64)); - assert!(can_cast_from_boolean(&DataType::Float32)); - assert!(can_cast_from_boolean(&DataType::Float64)); - assert!(can_cast_from_boolean(&DataType::Utf8)); - assert!(can_cast_from_boolean(&DataType::Decimal128(10, 4))); - assert!(!can_cast_from_boolean(&DataType::Null)); + fn test_is_df_cast_from_bool_spark_compatible() { + assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Boolean)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int8)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int16)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int32)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Int64)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float32)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Float64)); + assert!(is_df_cast_from_bool_spark_compatible(&DataType::Utf8)); + assert!(!is_df_cast_from_bool_spark_compatible( + &DataType::Decimal128(10, 4) + )); + assert!(!is_df_cast_from_bool_spark_compatible(&DataType::Null)); } #[test] diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index a1e8754cb5..74be1d733b 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::conversion_funcs::boolean::{can_cast_from_boolean, cast_boolean_to_decimal}; +use crate::conversion_funcs::boolean::{ + cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible, +}; +use crate::conversion_funcs::utils::spark_cast_postprocess; use crate::conversion_funcs::utils::{cast_overflow, invalid_value}; -use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess}; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; use crate::{timezone, BinaryOutputStyle}; @@ -1123,12 +1125,12 @@ fn cast_binary_formatter(value: &[u8]) -> String { /// Determines if DataFusion supports the given cast in a way that is /// compatible with Spark fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { - is_identity_cast(from_type, to_type) + from_type == to_type || match from_type { DataType::Null => { matches!(to_type, DataType::List(_)) } - DataType::Boolean => can_cast_from_boolean(to_type), + DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type), DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { matches!( to_type, diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs index 3a75792c71..2f6ea06b4d 100644 --- a/native/spark-expr/src/conversion_funcs/utils.rs +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -109,11 +109,6 @@ fn trim_end(s: &str) -> &str { } } -#[inline] -pub fn is_identity_cast(from_type: &DataType, to_type: &DataType) -> bool { - from_type == to_type -} - #[inline] pub fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { SparkError::CastOverFlow { From 05ece881c2a9cdc498140b26d64993479d061b16 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 19 Feb 2026 07:47:58 -0800 Subject: [PATCH 11/11] refactor_bool_cast_module --- .../spark-expr/benches/cast_from_boolean.rs | 1 - .../spark-expr/src/conversion_funcs/cast.rs | 96 ++++++++++--------- .../spark-expr/src/conversion_funcs/utils.rs | 2 +- 3 files changed, 51 insertions(+), 48 deletions(-) diff --git a/native/spark-expr/benches/cast_from_boolean.rs b/native/spark-expr/benches/cast_from_boolean.rs index db8e77f61b..dbb986df91 100644 --- a/native/spark-expr/benches/cast_from_boolean.rs +++ b/native/spark-expr/benches/cast_from_boolean.rs @@ -27,7 +27,6 @@ fn criterion_benchmark(c: &mut Criterion) { let expr = Arc::new(Column::new("a", 0)); let boolean_batch = create_boolean_batch(); let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); - Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index e3ee71cac2..004668b8f2 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -69,6 +69,8 @@ use std::{ static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); +pub(crate) const MICROS_PER_SECOND: i64 = 1000000; + static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -1166,26 +1168,16 @@ fn cast_binary_formatter(value: &[u8]) -> String { /// Determines if DataFusion supports the given cast in a way that is /// compatible with Spark fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { - from_type == to_type - || match from_type { - DataType::Null => { - matches!(to_type, DataType::List(_)) - } - DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - matches!( - to_type, - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - ) - } - DataType::Float32 | DataType::Float64 => matches!( + if from_type == to_type { + return true; + } + match from_type { + DataType::Null => { + matches!(to_type, DataType::List(_)) + } + DataType::Boolean => is_df_cast_from_bool_spark_compatible(to_type), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + matches!( to_type, DataType::Boolean | DataType::Int8 @@ -1194,34 +1186,46 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b | DataType::Int64 | DataType::Float32 | DataType::Float64 - ), - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Utf8 // note that there can be formatting differences + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - | DataType::Utf8 // note that there can be formatting differences - ), - DataType::Utf8 => matches!(to_type, DataType::Binary), - DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), - DataType::Timestamp(_, _) => { - matches!( - to_type, - DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) - ) - } - DataType::Binary => { - // note that this is not completely Spark compatible because - // DataFusion only supports binary data containing valid UTF-8 strings - matches!(to_type, DataType::Utf8) - } - _ => false, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) } + _ => false, + } } /// Cast between struct types based on logic in diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs index 2f6ea06b4d..8b8d974ffe 100644 --- a/native/spark-expr/src/conversion_funcs/utils.rs +++ b/native/spark-expr/src/conversion_funcs/utils.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::cast::MICROS_PER_SECOND; use crate::SparkError; use arrow::array::{ Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, @@ -26,7 +27,6 @@ use datafusion::common::cast::as_generic_string_array; use num::integer::div_floor; use std::sync::Arc; -const MICROS_PER_SECOND: i64 = 1000000; /// A fork & modified version of Arrow's `unary_dyn` which is being deprecated pub fn unary_dyn(array: &ArrayRef, op: F) -> Result where