Skip to content

Commit b56aa04

Browse files
committed
fix regression and update benchmark
1 parent 37b90fb commit b56aa04

2 files changed

Lines changed: 33 additions & 0 deletions

File tree

native/spark-expr/benches/cast_from_string.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323
use std::sync::Arc;
2424

2525
fn criterion_benchmark(c: &mut Criterion) {
26+
let small_int_batch = create_small_int_string_batch();
2627
let int_batch = create_int_string_batch();
2728
let decimal_batch = create_decimal_string_batch();
2829
let expr = Arc::new(Column::new("a", 0));
@@ -33,10 +34,18 @@ fn criterion_benchmark(c: &mut Criterion) {
3334
(EvalMode::Try, "try"),
3435
] {
3536
let spark_cast_options = SparkCastOptions::new(mode, "", false);
37+
let cast_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
38+
let cast_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
3639
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
3740
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
3841

3942
let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name));
43+
group.bench_function("i8", |b| {
44+
b.iter(|| cast_to_i8.evaluate(&small_int_batch).unwrap());
45+
});
46+
group.bench_function("i16", |b| {
47+
b.iter(|| cast_to_i16.evaluate(&small_int_batch).unwrap());
48+
});
4049
group.bench_function("i32", |b| {
4150
b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap());
4251
});
@@ -61,6 +70,21 @@ fn criterion_benchmark(c: &mut Criterion) {
6170
group.finish();
6271
}
6372

73+
/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
74+
fn create_small_int_string_batch() -> RecordBatch {
75+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
76+
let mut b = StringBuilder::new();
77+
for i in 0..1000 {
78+
if i % 10 == 0 {
79+
b.append_null();
80+
} else {
81+
b.append_value(format!("{}", rand::random::<i8>()));
82+
}
83+
}
84+
let array = b.finish();
85+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
86+
}
87+
6488
/// Create batch with valid integer strings (works for all eval modes)
6589
fn create_int_string_batch() -> RecordBatch {
6690
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,6 +3004,15 @@ mod tests {
30043004

30053005
use super::*;
30063006

3007+
/// Test helper that wraps the mode-specific parse functions
3008+
fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
3009+
match eval_mode {
3010+
EvalMode::Legacy => parse_string_to_i8_legacy(str),
3011+
EvalMode::Ansi => parse_string_to_i8_ansi(str),
3012+
EvalMode::Try => parse_string_to_i8_try(str),
3013+
}
3014+
}
3015+
30073016
#[test]
30083017
#[cfg_attr(miri, ignore)] // test takes too long with miri
30093018
fn timestamp_parser_test() {

0 commit comments

Comments
 (0)