Skip to content

Commit fec6217

Browse files
committed
improve criterion benchmarks for cast string to int
1 parent 069681a commit fec6217

1 file changed

Lines changed: 52 additions & 22 deletions

File tree

native/spark-expr/benches/cast_from_string.rs

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,45 +23,75 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323
use std::sync::Arc;
2424

2525
fn criterion_benchmark(c: &mut Criterion) {
26-
let batch = create_utf8_batch();
26+
let int_batch = create_int_string_batch();
27+
let decimal_batch = create_decimal_string_batch();
2728
let expr = Arc::new(Column::new("a", 0));
29+
30+
for (mode, mode_name) in [
31+
(EvalMode::Legacy, "legacy"),
32+
(EvalMode::Ansi, "ansi"),
33+
(EvalMode::Try, "try"),
34+
] {
35+
let spark_cast_options = SparkCastOptions::new(mode, "", false);
36+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
37+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
38+
39+
let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name));
40+
group.bench_function("i32", |b| {
41+
b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap());
42+
});
43+
group.bench_function("i64", |b| {
44+
b.iter(|| cast_to_i64.evaluate(&int_batch).unwrap());
45+
});
46+
group.finish();
47+
}
48+
49+
// Benchmark decimal truncation (Legacy mode only)
2850
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
29-
let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
30-
let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
31-
let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
32-
let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);
51+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
52+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
3353

34-
let mut group = c.benchmark_group("cast_string_to_int");
35-
group.bench_function("cast_string_to_i8", |b| {
36-
b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap());
54+
let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals");
55+
group.bench_function("i32", |b| {
56+
b.iter(|| cast_to_i32.evaluate(&decimal_batch).unwrap());
3757
});
38-
group.bench_function("cast_string_to_i16", |b| {
39-
b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap());
40-
});
41-
group.bench_function("cast_string_to_i32", |b| {
42-
b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap());
43-
});
44-
group.bench_function("cast_string_to_i64", |b| {
45-
b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap());
58+
group.bench_function("i64", |b| {
59+
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
4660
});
61+
group.finish();
4762
}
4863

49-
// Create UTF8 batch with strings representing ints, floats, nulls
50-
fn create_utf8_batch() -> RecordBatch {
64+
/// Create batch with valid integer strings (works for all eval modes)
65+
fn create_int_string_batch() -> RecordBatch {
5166
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
5267
let mut b = StringBuilder::new();
5368
for i in 0..1000 {
5469
if i % 10 == 0 {
5570
b.append_null();
56-
} else if i % 2 == 0 {
57-
b.append_value(format!("{}", rand::random::<f64>()));
5871
} else {
59-
b.append_value(format!("{}", rand::random::<i64>()));
72+
b.append_value(format!("{}", rand::random::<i32>()));
6073
}
6174
}
6275
let array = b.finish();
76+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
77+
}
6378

64-
RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap()
79+
/// Create batch with decimal strings (for Legacy mode decimal truncation)
80+
fn create_decimal_string_batch() -> RecordBatch {
81+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
82+
let mut b = StringBuilder::new();
83+
for i in 0..1000 {
84+
if i % 10 == 0 {
85+
b.append_null();
86+
} else {
87+
// Generate integers with decimal portions to test truncation
88+
let int_part: i32 = rand::random();
89+
let dec_part: u32 = rand::random::<u32>() % 1000;
90+
b.append_value(format!("{}.{}", int_part, dec_part));
91+
}
92+
}
93+
let array = b.finish();
94+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
6595
}
6696

6797
fn config() -> Criterion {

0 commit comments

Comments
 (0)