@@ -23,45 +23,75 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323use std:: sync:: Arc ;
2424
2525fn criterion_benchmark ( c : & mut Criterion ) {
26- let batch = create_utf8_batch ( ) ;
26+ let int_batch = create_int_string_batch ( ) ;
27+ let decimal_batch = create_decimal_string_batch ( ) ;
2728 let expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
29+
30+ for ( mode, mode_name) in [
31+ ( EvalMode :: Legacy , "legacy" ) ,
32+ ( EvalMode :: Ansi , "ansi" ) ,
33+ ( EvalMode :: Try , "try" ) ,
34+ ] {
35+ let spark_cast_options = SparkCastOptions :: new ( mode, "" , false ) ;
36+ let cast_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
37+ let cast_to_i64 = Cast :: new ( expr. clone ( ) , DataType :: Int64 , spark_cast_options) ;
38+
39+ let mut group = c. benchmark_group ( format ! ( "cast_string_to_int/{}" , mode_name) ) ;
40+ group. bench_function ( "i32" , |b| {
41+ b. iter ( || cast_to_i32. evaluate ( & int_batch) . unwrap ( ) ) ;
42+ } ) ;
43+ group. bench_function ( "i64" , |b| {
44+ b. iter ( || cast_to_i64. evaluate ( & int_batch) . unwrap ( ) ) ;
45+ } ) ;
46+ group. finish ( ) ;
47+ }
48+
49+ // Benchmark decimal truncation (Legacy mode only)
2850 let spark_cast_options = SparkCastOptions :: new ( EvalMode :: Legacy , "" , false ) ;
29- let cast_string_to_i8 = Cast :: new ( expr. clone ( ) , DataType :: Int8 , spark_cast_options. clone ( ) ) ;
30- let cast_string_to_i16 = Cast :: new ( expr. clone ( ) , DataType :: Int16 , spark_cast_options. clone ( ) ) ;
31- let cast_string_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
32- let cast_string_to_i64 = Cast :: new ( expr, DataType :: Int64 , spark_cast_options) ;
51+ let cast_to_i32 = Cast :: new ( expr. clone ( ) , DataType :: Int32 , spark_cast_options. clone ( ) ) ;
52+ let cast_to_i64 = Cast :: new ( expr. clone ( ) , DataType :: Int64 , spark_cast_options) ;
3353
34- let mut group = c. benchmark_group ( "cast_string_to_int" ) ;
35- group. bench_function ( "cast_string_to_i8 " , |b| {
36- b. iter ( || cast_string_to_i8 . evaluate ( & batch ) . unwrap ( ) ) ;
54+ let mut group = c. benchmark_group ( "cast_string_to_int/legacy_decimals " ) ;
55+ group. bench_function ( "i32 " , |b| {
56+ b. iter ( || cast_to_i32 . evaluate ( & decimal_batch ) . unwrap ( ) ) ;
3757 } ) ;
38- group. bench_function ( "cast_string_to_i16" , |b| {
39- b. iter ( || cast_string_to_i16. evaluate ( & batch) . unwrap ( ) ) ;
40- } ) ;
41- group. bench_function ( "cast_string_to_i32" , |b| {
42- b. iter ( || cast_string_to_i32. evaluate ( & batch) . unwrap ( ) ) ;
43- } ) ;
44- group. bench_function ( "cast_string_to_i64" , |b| {
45- b. iter ( || cast_string_to_i64. evaluate ( & batch) . unwrap ( ) ) ;
58+ group. bench_function ( "i64" , |b| {
59+ b. iter ( || cast_to_i64. evaluate ( & decimal_batch) . unwrap ( ) ) ;
4660 } ) ;
61+ group. finish ( ) ;
4762}
4863
49- // Create UTF8 batch with strings representing ints, floats, nulls
50- fn create_utf8_batch ( ) -> RecordBatch {
64+ /// Create batch with valid integer strings (works for all eval modes)
65+ fn create_int_string_batch ( ) -> RecordBatch {
5166 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
5267 let mut b = StringBuilder :: new ( ) ;
5368 for i in 0 ..1000 {
5469 if i % 10 == 0 {
5570 b. append_null ( ) ;
56- } else if i % 2 == 0 {
57- b. append_value ( format ! ( "{}" , rand:: random:: <f64 >( ) ) ) ;
5871 } else {
59- b. append_value ( format ! ( "{}" , rand:: random:: <i64 >( ) ) ) ;
72+ b. append_value ( format ! ( "{}" , rand:: random:: <i32 >( ) ) ) ;
6073 }
6174 }
6275 let array = b. finish ( ) ;
76+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( )
77+ }
6378
64- RecordBatch :: try_new ( schema. clone ( ) , vec ! [ Arc :: new( array) ] ) . unwrap ( )
79+ /// Create batch with decimal strings (for Legacy mode decimal truncation)
80+ fn create_decimal_string_batch ( ) -> RecordBatch {
81+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Utf8 , true ) ] ) ) ;
82+ let mut b = StringBuilder :: new ( ) ;
83+ for i in 0 ..1000 {
84+ if i % 10 == 0 {
85+ b. append_null ( ) ;
86+ } else {
87+ // Generate integers with decimal portions to test truncation
88+ let int_part: i32 = rand:: random ( ) ;
89+ let dec_part: u32 = rand:: random :: < u32 > ( ) % 1000 ;
90+ b. append_value ( format ! ( "{}.{}" , int_part, dec_part) ) ;
91+ }
92+ }
93+ let array = b. finish ( ) ;
94+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( array) ] ) . unwrap ( )
6595}
6696
6797fn config ( ) -> Criterion {
0 commit comments