Skip to content

Commit dae6a25

Browse files
committed
refactor_boolean_cast_ops_add_benchmarks_rebase_main
1 parent f8c3a64 commit dae6a25

4 files changed

Lines changed: 45 additions & 14 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,5 @@ name = "test_udf_registration"
9797
path = "tests/spark_expr_reg.rs"
9898

9999
[[bench]]
100-
name = "cast_boolean"
100+
name = "cast_from_boolean"
101101
harness = false

native/spark-expr/benches/cast_boolean.rs renamed to native/spark-expr/benches/cast_from_boolean.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ fn criterion_benchmark(c: &mut Criterion) {
3434
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options.clone());
3535
let cast_to_f32 = Cast::new(expr.clone(), DataType::Float32, spark_cast_options.clone());
3636
let cast_to_f64 = Cast::new(expr.clone(), DataType::Float64, spark_cast_options.clone());
37-
let cast_to_str = Cast::new(expr, DataType::Utf8, spark_cast_options);
37+
let cast_to_str = Cast::new(expr.clone(), DataType::Utf8, spark_cast_options.clone());
38+
let cast_to_decimal = Cast::new(expr, DataType::Decimal128(10, 4), spark_cast_options);
3839

39-
let mut group = c.benchmark_group("cast_bool_to_int".to_string());
40+
let mut group = c.benchmark_group("cast_bool".to_string());
4041
group.bench_function("i8", |b| {
4142
b.iter(|| cast_to_i8.evaluate(&boolean_batch).unwrap());
4243
});
@@ -58,6 +59,9 @@ fn criterion_benchmark(c: &mut Criterion) {
5859
group.bench_function("str", |b| {
5960
b.iter(|| cast_to_str.evaluate(&boolean_batch).unwrap());
6061
});
62+
group.bench_function("decimal", |b| {
63+
b.iter(|| cast_to_decimal.evaluate(&boolean_batch).unwrap());
64+
});
6165
}
6266

6367
fn create_boolean_batch() -> RecordBatch {

native/spark-expr/src/conversion_funcs/boolean.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::SparkResult;
19+
use arrow::array::{ArrayRef, AsArray, Decimal128Array};
1820
use arrow::datatypes::DataType;
21+
use std::sync::Arc;
1922

2023
pub fn can_cast_from_boolean(to_type: &DataType) -> bool {
2124
use DataType::*;
@@ -25,6 +28,21 @@ pub fn can_cast_from_boolean(to_type: &DataType) -> bool {
2528
)
2629
}
2730

31+
// only incompatible boolean cast
32+
pub fn cast_boolean_to_decimal(
33+
array: &ArrayRef,
34+
precision: u8,
35+
scale: i8,
36+
) -> SparkResult<ArrayRef> {
37+
let bool_array = array.as_boolean();
38+
let scaled_val = 10_i128.pow(scale as u32);
39+
let result: Decimal128Array = bool_array
40+
.iter()
41+
.map(|v| v.map(|b| if b { scaled_val } else { 0 }))
42+
.collect();
43+
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
44+
}
45+
2846
#[cfg(test)]
2947
mod tests {
3048
use super::*;
@@ -34,6 +52,7 @@ mod tests {
3452
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
3553
Int64Array, Int8Array, StringArray,
3654
};
55+
use arrow::datatypes::DataType::Decimal128;
3756
use std::sync::Arc;
3857

3958
fn test_input_bool_array() -> ArrayRef {
@@ -54,6 +73,7 @@ mod tests {
5473
assert!(can_cast_from_boolean(&DataType::Float32));
5574
assert!(can_cast_from_boolean(&DataType::Float64));
5675
assert!(can_cast_from_boolean(&DataType::Utf8));
76+
assert!(can_cast_from_boolean(&DataType::Decimal128(10, 4)));
5777
assert!(!can_cast_from_boolean(&DataType::Null));
5878
}
5979

@@ -154,4 +174,21 @@ mod tests {
154174
assert_eq!(arr.value(1), "false");
155175
assert!(arr.is_null(2));
156176
}
177+
178+
#[test]
179+
fn test_bool_to_decimal_cast() {
180+
let result = cast_array(
181+
test_input_bool_array(),
182+
&Decimal128(10, 4),
183+
&test_input_spark_opts(),
184+
)
185+
.unwrap();
186+
let expected_arr = Decimal128Array::from(vec![10000_i128, 0_i128])
187+
.with_precision_and_scale(10, 4)
188+
.unwrap();
189+
let arr = result.as_any().downcast_ref::<Decimal128Array>().unwrap();
190+
assert_eq!(arr.value(0), expected_arr.value(0));
191+
assert_eq!(arr.value(1), expected_arr.value(1));
192+
assert!(arr.is_null(2));
193+
}
157194
}

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::conversion_funcs::boolean::can_cast_from_boolean;
18+
use crate::conversion_funcs::boolean::{can_cast_from_boolean, cast_boolean_to_decimal};
1919
use crate::conversion_funcs::utils::{cast_overflow, invalid_value};
2020
use crate::conversion_funcs::utils::{is_identity_cast, spark_cast_postprocess};
2121
use crate::utils::array_with_timezone;
@@ -1187,16 +1187,6 @@ fn cast_date_to_timestamp(
11871187
))
11881188
}
11891189

1190-
fn cast_boolean_to_decimal(array: &ArrayRef, precision: u8, scale: i8) -> SparkResult<ArrayRef> {
1191-
let bool_array = array.as_boolean();
1192-
let scaled_val = 10_i128.pow(scale as u32);
1193-
let result: Decimal128Array = bool_array
1194-
.iter()
1195-
.map(|v| v.map(|b| if b { scaled_val } else { 0 }))
1196-
.collect();
1197-
Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
1198-
}
1199-
12001190
fn cast_string_to_float(
12011191
array: &ArrayRef,
12021192
to_type: &DataType,

0 commit comments

Comments
 (0)