Skip to content

Commit e28d52a

Browse files
committed
remove per-row eval mode check and expand benchmarks
1 parent cbf68bb commit e28d52a

2 files changed

Lines changed: 155 additions & 90 deletions

File tree

native/spark-expr/benches/cast_from_string.rs

Lines changed: 52 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -23,45 +23,75 @@ use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
2323
use std::sync::Arc;
2424

2525
fn criterion_benchmark(c: &mut Criterion) {
26-
let batch = create_utf8_batch();
26+
let int_batch = create_int_string_batch();
27+
let decimal_batch = create_decimal_string_batch();
2728
let expr = Arc::new(Column::new("a", 0));
29+
30+
for (mode, mode_name) in [
31+
(EvalMode::Legacy, "legacy"),
32+
(EvalMode::Ansi, "ansi"),
33+
(EvalMode::Try, "try"),
34+
] {
35+
let spark_cast_options = SparkCastOptions::new(mode, "", false);
36+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
37+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
38+
39+
let mut group = c.benchmark_group(format!("cast_string_to_int/{}", mode_name));
40+
group.bench_function("i32", |b| {
41+
b.iter(|| cast_to_i32.evaluate(&int_batch).unwrap());
42+
});
43+
group.bench_function("i64", |b| {
44+
b.iter(|| cast_to_i64.evaluate(&int_batch).unwrap());
45+
});
46+
group.finish();
47+
}
48+
49+
// Benchmark decimal truncation (Legacy mode only)
2850
let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false);
29-
let cast_string_to_i8 = Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone());
30-
let cast_string_to_i16 = Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone());
31-
let cast_string_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
32-
let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options);
51+
let cast_to_i32 = Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone());
52+
let cast_to_i64 = Cast::new(expr.clone(), DataType::Int64, spark_cast_options);
3353

34-
let mut group = c.benchmark_group("cast_string_to_int");
35-
group.bench_function("cast_string_to_i8", |b| {
36-
b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap());
54+
let mut group = c.benchmark_group("cast_string_to_int/legacy_decimals");
55+
group.bench_function("i32", |b| {
56+
b.iter(|| cast_to_i32.evaluate(&decimal_batch).unwrap());
3757
});
38-
group.bench_function("cast_string_to_i16", |b| {
39-
b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap());
40-
});
41-
group.bench_function("cast_string_to_i32", |b| {
42-
b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap());
43-
});
44-
group.bench_function("cast_string_to_i64", |b| {
45-
b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap());
58+
group.bench_function("i64", |b| {
59+
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
4660
});
61+
group.finish();
4762
}
4863

49-
// Create UTF8 batch with strings representing ints, floats, nulls
50-
fn create_utf8_batch() -> RecordBatch {
64+
/// Create batch with valid integer strings (works for all eval modes)
65+
fn create_int_string_batch() -> RecordBatch {
5166
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
5267
let mut b = StringBuilder::new();
5368
for i in 0..1000 {
5469
if i % 10 == 0 {
5570
b.append_null();
56-
} else if i % 2 == 0 {
57-
b.append_value(format!("{}", rand::random::<f64>()));
5871
} else {
59-
b.append_value(format!("{}", rand::random::<i64>()));
72+
b.append_value(format!("{}", rand::random::<i32>()));
6073
}
6174
}
6275
let array = b.finish();
76+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
77+
}
6378

64-
RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap()
79+
/// Create batch with decimal strings (for Legacy mode decimal truncation)
80+
fn create_decimal_string_batch() -> RecordBatch {
81+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
82+
let mut b = StringBuilder::new();
83+
for i in 0..1000 {
84+
if i % 10 == 0 {
85+
b.append_null();
86+
} else {
87+
// Generate integers with decimal portions to test truncation
88+
let int_part: i32 = rand::random();
89+
let dec_part: u32 = rand::random::<u32>() % 1000;
90+
b.append_value(format!("{}.{}", int_part, dec_part));
91+
}
92+
}
93+
let array = b.finish();
94+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
6595
}
6696

6797
fn config() -> Criterion {

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 103 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -386,12 +386,13 @@ fn can_cast_from_decimal(
386386
}
387387

388388
macro_rules! cast_utf8_to_int {
389-
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
389+
($array:expr, $array_type:ty, $parse_fn:expr) => {{
390390
let len = $array.len();
391391
let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
392+
let parse_fn = $parse_fn;
392393
if $array.null_count() == 0 {
393394
for i in 0..len {
394-
if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
395+
if let Some(cast_value) = parse_fn($array.value(i))? {
395396
cast_array.append_value(cast_value);
396397
} else {
397398
cast_array.append_null()
@@ -401,7 +402,7 @@ macro_rules! cast_utf8_to_int {
401402
for i in 0..len {
402403
if $array.is_null(i) {
403404
cast_array.append_null()
404-
} else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? {
405+
} else if let Some(cast_value) = parse_fn($array.value(i))? {
405406
cast_array.append_value(cast_value);
406407
} else {
407408
cast_array.append_null()
@@ -1473,22 +1474,70 @@ fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
14731474
.downcast_ref::<GenericStringArray<OffsetSize>>()
14741475
.expect("cast_string_to_int expected a string array");
14751476

1476-
let cast_array: ArrayRef = match to_type {
1477-
DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?,
1478-
DataType::Int16 => {
1479-
cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
1480-
}
1481-
DataType::Int32 => {
1482-
cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
1483-
}
1484-
DataType::Int64 => {
1485-
cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
1486-
}
1487-
dt => unreachable!(
1488-
"{}",
1489-
format!("invalid integer type {dt} in cast from string")
1490-
),
1491-
};
1477+
// Select parse function once per batch based on eval_mode
1478+
let cast_array: ArrayRef =
1479+
match (to_type, eval_mode) {
1480+
(DataType::Int8, EvalMode::Legacy) => {
1481+
cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_legacy)?
1482+
}
1483+
(DataType::Int8, EvalMode::Ansi) => {
1484+
cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_ansi)?
1485+
}
1486+
(DataType::Int8, EvalMode::Try) => {
1487+
cast_utf8_to_int!(string_array, Int8Type, parse_string_to_i8_try)?
1488+
}
1489+
(DataType::Int16, EvalMode::Legacy) => {
1490+
cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_legacy)?
1491+
}
1492+
(DataType::Int16, EvalMode::Ansi) => {
1493+
cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_ansi)?
1494+
}
1495+
(DataType::Int16, EvalMode::Try) => {
1496+
cast_utf8_to_int!(string_array, Int16Type, parse_string_to_i16_try)?
1497+
}
1498+
(DataType::Int32, EvalMode::Legacy) => cast_utf8_to_int!(
1499+
string_array,
1500+
Int32Type,
1501+
|s| do_parse_string_to_int_legacy::<i32>(s, i32::MIN)
1502+
)?,
1503+
(DataType::Int32, EvalMode::Ansi) => {
1504+
cast_utf8_to_int!(string_array, Int32Type, |s| do_parse_string_to_int_ansi::<
1505+
i32,
1506+
>(
1507+
s, "INT", i32::MIN
1508+
))?
1509+
}
1510+
(DataType::Int32, EvalMode::Try) => {
1511+
cast_utf8_to_int!(
1512+
string_array,
1513+
Int32Type,
1514+
|s| do_parse_string_to_int_try::<i32>(s, i32::MIN)
1515+
)?
1516+
}
1517+
(DataType::Int64, EvalMode::Legacy) => cast_utf8_to_int!(
1518+
string_array,
1519+
Int64Type,
1520+
|s| do_parse_string_to_int_legacy::<i64>(s, i64::MIN)
1521+
)?,
1522+
(DataType::Int64, EvalMode::Ansi) => {
1523+
cast_utf8_to_int!(string_array, Int64Type, |s| do_parse_string_to_int_ansi::<
1524+
i64,
1525+
>(
1526+
s, "BIGINT", i64::MIN
1527+
))?
1528+
}
1529+
(DataType::Int64, EvalMode::Try) => {
1530+
cast_utf8_to_int!(
1531+
string_array,
1532+
Int64Type,
1533+
|s| do_parse_string_to_int_try::<i64>(s, i64::MIN)
1534+
)?
1535+
}
1536+
(dt, _) => unreachable!(
1537+
"{}",
1538+
format!("invalid integer type {dt} in cast from string")
1539+
),
1540+
};
14921541
Ok(cast_array)
14931542
}
14941543

@@ -1960,51 +2009,50 @@ fn spark_cast_nonintegral_numeric_to_integral(
19602009
}
19612010
}
19622011

1963-
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
1964-
fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i8>> {
1965-
Ok(cast_string_to_int_with_range_check(
1966-
str,
1967-
eval_mode,
1968-
"TINYINT",
1969-
i8::MIN as i32,
1970-
i8::MAX as i32,
1971-
)?
1972-
.map(|v| v as i8))
2012+
fn parse_string_to_i8_legacy(str: &str) -> SparkResult<Option<i8>> {
2013+
match do_parse_string_to_int_legacy::<i32>(str, i32::MIN)? {
2014+
None => Ok(None),
2015+
Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
2016+
_ => Ok(None),
2017+
}
19732018
}
19742019

1975-
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
1976-
fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i16>> {
1977-
Ok(cast_string_to_int_with_range_check(
1978-
str,
1979-
eval_mode,
1980-
"SMALLINT",
1981-
i16::MIN as i32,
1982-
i16::MAX as i32,
1983-
)?
1984-
.map(|v| v as i16))
2020+
fn parse_string_to_i8_ansi(str: &str) -> SparkResult<Option<i8>> {
2021+
match do_parse_string_to_int_ansi::<i32>(str, "TINYINT", i32::MIN)? {
2022+
None => Ok(None),
2023+
Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
2024+
_ => Err(invalid_value(str, "STRING", "TINYINT")),
2025+
}
19852026
}
19862027

1987-
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
1988-
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i32>> {
1989-
do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
2028+
fn parse_string_to_i8_try(str: &str) -> SparkResult<Option<i8>> {
2029+
match do_parse_string_to_int_try::<i32>(str, i32::MIN)? {
2030+
None => Ok(None),
2031+
Some(v) if v >= i8::MIN as i32 && v <= i8::MAX as i32 => Ok(Some(v as i8)),
2032+
_ => Ok(None),
2033+
}
19902034
}
19912035

1992-
/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper)
1993-
fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult<Option<i64>> {
1994-
do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
2036+
fn parse_string_to_i16_legacy(str: &str) -> SparkResult<Option<i16>> {
2037+
match do_parse_string_to_int_legacy::<i32>(str, i32::MIN)? {
2038+
None => Ok(None),
2039+
Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
2040+
_ => Ok(None),
2041+
}
19952042
}
19962043

1997-
fn cast_string_to_int_with_range_check(
1998-
str: &str,
1999-
eval_mode: EvalMode,
2000-
type_name: &str,
2001-
min: i32,
2002-
max: i32,
2003-
) -> SparkResult<Option<i32>> {
2004-
match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
2044+
fn parse_string_to_i16_ansi(str: &str) -> SparkResult<Option<i16>> {
2045+
match do_parse_string_to_int_ansi::<i32>(str, "SMALLINT", i32::MIN)? {
20052046
None => Ok(None),
2006-
Some(v) if v >= min && v <= max => Ok(Some(v)),
2007-
_ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
2047+
Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
2048+
_ => Err(invalid_value(str, "STRING", "SMALLINT")),
2049+
}
2050+
}
2051+
2052+
fn parse_string_to_i16_try(str: &str) -> SparkResult<Option<i16>> {
2053+
match do_parse_string_to_int_try::<i32>(str, i32::MIN)? {
2054+
None => Ok(None),
2055+
Some(v) if v >= i16::MIN as i32 && v <= i16::MAX as i32 => Ok(Some(v as i16)),
20082056
_ => Ok(None),
20092057
}
20102058
}
@@ -2210,19 +2258,6 @@ fn do_parse_string_to_int_try<T: Integer + CheckedSub + CheckedNeg + From<u8> +
22102258
Ok(Some(result))
22112259
}
22122260

2213-
fn do_cast_string_to_int<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
2214-
str: &str,
2215-
eval_mode: EvalMode,
2216-
type_name: &str,
2217-
min_value: T,
2218-
) -> SparkResult<Option<T>> {
2219-
match eval_mode {
2220-
EvalMode::Legacy => do_parse_string_to_int_legacy(str, min_value),
2221-
EvalMode::Ansi => do_parse_string_to_int_ansi(str, type_name, min_value),
2222-
EvalMode::Try => do_parse_string_to_int_try(str, min_value),
2223-
}
2224-
}
2225-
22262261
fn cast_string_to_decimal(
22272262
array: &ArrayRef,
22282263
to_type: &DataType,

0 commit comments

Comments
 (0)