Skip to content

Commit cf2b2a4

Browse files
committed
improve performance of checkoverflow
1 parent dca45ea commit cf2b2a4

3 files changed

Lines changed: 239 additions & 31 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ harness = false
8080
name = "padding"
8181
harness = false
8282

83+
[[bench]]
84+
name = "check_overflow"
85+
harness = false
86+
8387
[[test]]
8488
name = "test_udf_registration"
8589
path = "tests/spark_expr_reg.rs"
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::builder::Decimal128Builder;
19+
use arrow::array::RecordBatch;
20+
use arrow::datatypes::{DataType, Field, Schema};
21+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
22+
use datafusion::physical_expr::PhysicalExpr;
23+
use datafusion_comet_spark_expr::CheckOverflow;
24+
use std::sync::Arc;
25+
26+
fn create_decimal_batch(size: usize, precision: u8, scale: i8, with_nulls: bool) -> RecordBatch {
27+
let schema = Arc::new(Schema::new(vec![Field::new(
28+
"a",
29+
DataType::Decimal128(precision, scale),
30+
true,
31+
)]));
32+
let mut builder = Decimal128Builder::with_capacity(size);
33+
34+
for i in 0..size {
35+
if with_nulls && i % 10 == 0 {
36+
builder.append_null();
37+
} else {
38+
// Values that fit within precision 10 (max ~9999999999)
39+
builder.append_value((i as i128) * 12345);
40+
}
41+
}
42+
43+
let array = builder
44+
.finish()
45+
.with_precision_and_scale(precision, scale)
46+
.unwrap();
47+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
48+
}
49+
50+
fn create_batch_with_overflow(
51+
size: usize,
52+
input_precision: u8,
53+
target_precision: u8,
54+
scale: i8,
55+
) -> RecordBatch {
56+
let schema = Arc::new(Schema::new(vec![Field::new(
57+
"a",
58+
DataType::Decimal128(input_precision, scale),
59+
true,
60+
)]));
61+
let mut builder = Decimal128Builder::with_capacity(size);
62+
63+
// Create values where ~10% will overflow the target precision
64+
let max_for_target = 10i128.pow(target_precision as u32) - 1;
65+
for i in 0..size {
66+
if i % 10 == 0 {
67+
// This value will overflow target precision
68+
builder.append_value(max_for_target + (i as i128) + 1);
69+
} else {
70+
// This value is within target precision
71+
builder.append_value((i as i128) % max_for_target);
72+
}
73+
}
74+
75+
let array = builder
76+
.finish()
77+
.with_precision_and_scale(input_precision, scale)
78+
.unwrap();
79+
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
80+
}
81+
82+
fn criterion_benchmark(c: &mut Criterion) {
83+
let sizes = [1000, 10000];
84+
85+
let mut group = c.benchmark_group("check_overflow");
86+
87+
for size in sizes {
88+
// Benchmark: No overflow possible (precision already fits)
89+
// This tests the fast path where input precision <= target precision
90+
let batch_no_overflow = create_decimal_batch(size, 10, 2, false);
91+
92+
// Create CheckOverflow that goes from precision 10 to 18 (no overflow possible)
93+
let check_overflow_no_validation = Arc::new(CheckOverflow::new(
94+
Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
95+
DataType::Decimal128(18, 2), // larger precision = no overflow possible
96+
false,
97+
));
98+
99+
group.bench_with_input(
100+
BenchmarkId::new("no_overflow_possible", size),
101+
&batch_no_overflow,
102+
|b, batch| {
103+
b.iter(|| check_overflow_no_validation.evaluate(batch).unwrap());
104+
},
105+
);
106+
107+
// Benchmark: Validation needed, but no overflows occur
108+
let batch_valid = create_decimal_batch(size, 18, 2, true);
109+
let check_overflow_valid = Arc::new(CheckOverflow::new(
110+
Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
111+
DataType::Decimal128(10, 2), // smaller precision, need to validate
112+
false,
113+
));
114+
115+
group.bench_with_input(
116+
BenchmarkId::new("validation_no_overflow", size),
117+
&batch_valid,
118+
|b, batch| {
119+
b.iter(|| check_overflow_valid.evaluate(batch).unwrap());
120+
},
121+
);
122+
123+
// Benchmark: With ~10% overflows (requires null insertion)
124+
let batch_with_overflow = create_batch_with_overflow(size, 18, 8, 2);
125+
let check_overflow_with_nulls = Arc::new(CheckOverflow::new(
126+
Arc::new(datafusion::physical_expr::expressions::Column::new("a", 0)),
127+
DataType::Decimal128(8, 2),
128+
false,
129+
));
130+
131+
group.bench_with_input(
132+
BenchmarkId::new("with_overflow_to_null", size),
133+
&batch_with_overflow,
134+
|b, batch| {
135+
b.iter(|| check_overflow_with_nulls.evaluate(batch).unwrap());
136+
},
137+
);
138+
}
139+
140+
group.finish();
141+
}
142+
143+
fn config() -> Criterion {
144+
Criterion::default()
145+
}
146+
147+
criterion_group! {
148+
name = benches;
149+
config = config();
150+
targets = criterion_benchmark
151+
}
152+
criterion_main!(benches);

native/spark-expr/src/math_funcs/internal/checkoverflow.rs

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::is_valid_decimal_precision;
1819
use arrow::datatypes::{DataType, Schema};
1920
use arrow::{
20-
array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
21-
datatypes::{Decimal128Type, DecimalType},
21+
array::{as_primitive_array, Array, Decimal128Array},
22+
datatypes::Decimal128Type,
2223
record_batch::RecordBatch,
2324
};
2425
use datafusion::common::{DataFusionError, ScalarValue};
@@ -101,8 +102,8 @@ impl PhysicalExpr for CheckOverflow {
101102
ColumnarValue::Array(array)
102103
if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
103104
{
104-
let (precision, scale) = match &self.data_type {
105-
DataType::Decimal128(p, s) => (p, s),
105+
let (target_precision, target_scale) = match &self.data_type {
106+
DataType::Decimal128(p, s) => (*p, *s),
106107
dt => {
107108
return Err(DataFusionError::Execution(format!(
108109
"CheckOverflow expects only Decimal128, but got {dt:?}"
@@ -112,38 +113,89 @@ impl PhysicalExpr for CheckOverflow {
112113

113114
let decimal_array = as_primitive_array::<Decimal128Type>(&array);
114115

115-
let casted_array = if self.fail_on_error {
116-
// Returning error if overflow
117-
decimal_array.validate_decimal_precision(*precision)?;
116+
// Get input precision to check if we can skip validation
117+
let (input_precision, input_scale) = match decimal_array.data_type() {
118+
DataType::Decimal128(p, s) => (*p, *s),
119+
_ => unreachable!(),
120+
};
121+
122+
// Optimization: if input precision <= target precision and scales match,
123+
// no overflow is possible - just update metadata
124+
if input_precision <= target_precision && input_scale == target_scale {
125+
let new_array = decimal_array
126+
.clone()
127+
.with_precision_and_scale(target_precision, target_scale)?;
128+
return Ok(ColumnarValue::Array(Arc::new(new_array)));
129+
}
130+
131+
let result_array = if self.fail_on_error {
132+
// ANSI mode: validate and return error on overflow
133+
// Use optimized validation that avoids error string allocation until needed
134+
for i in 0..decimal_array.len() {
135+
if decimal_array.is_valid(i) {
136+
let value = decimal_array.value(i);
137+
if !is_valid_decimal_precision(value, target_precision) {
138+
return Err(DataFusionError::ArrowError(
139+
Box::new(arrow::error::ArrowError::InvalidArgumentError(
140+
format!(
141+
"{} is not a valid Decimal128 value with precision {}",
142+
value, target_precision
143+
),
144+
)),
145+
None,
146+
));
147+
}
148+
}
149+
}
150+
// Validation passed - just update metadata without copying data
118151
decimal_array
152+
.clone()
153+
.with_precision_and_scale(target_precision, target_scale)?
119154
} else {
120-
// Overflowing gets null value
121-
&decimal_array.null_if_overflow_precision(*precision)
155+
// Legacy/Try mode: convert overflows to null
156+
// Use Arrow's optimized null_if_overflow_precision which does a single pass
157+
let result = decimal_array.null_if_overflow_precision(target_precision);
158+
result.with_precision_and_scale(target_precision, target_scale)?
122159
};
123160

124-
let new_array = Decimal128Array::from(casted_array.into_data())
125-
.with_precision_and_scale(*precision, *scale)
126-
.map(|a| Arc::new(a) as ArrayRef)?;
127-
128-
Ok(ColumnarValue::Array(new_array))
161+
Ok(ColumnarValue::Array(Arc::new(result_array)))
129162
}
130-
ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
131-
// `fail_on_error` is only true when ANSI is enabled, which we don't support yet
132-
// (Java side will simply fallback to Spark when it is enabled)
133-
assert!(
134-
!self.fail_on_error,
135-
"fail_on_error (ANSI mode) is not supported yet"
136-
);
137-
138-
let new_v: Option<i128> = v.and_then(|v| {
139-
Decimal128Type::validate_decimal_precision(v, precision, scale)
140-
.map(|_| v)
141-
.ok()
142-
});
143-
144-
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
145-
new_v, precision, scale,
146-
)))
163+
ColumnarValue::Scalar(ScalarValue::Decimal128(v, _, _)) => {
164+
let (target_precision, target_scale) = match &self.data_type {
165+
DataType::Decimal128(p, s) => (*p, *s),
166+
dt => {
167+
return Err(DataFusionError::Execution(format!(
168+
"CheckOverflow expects only Decimal128 for scalar, but got {dt:?}"
169+
)))
170+
}
171+
};
172+
173+
if self.fail_on_error {
174+
if let Some(value) = v {
175+
if !is_valid_decimal_precision(value, target_precision) {
176+
return Err(DataFusionError::ArrowError(
177+
Box::new(arrow::error::ArrowError::InvalidArgumentError(format!(
178+
"{} is not a valid Decimal128 value with precision {}",
179+
value, target_precision
180+
))),
181+
None,
182+
));
183+
}
184+
}
185+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
186+
v,
187+
target_precision,
188+
target_scale,
189+
)))
190+
} else {
191+
// Use optimized bool check instead of Result-returning validation
192+
let new_v = v.filter(|&val| is_valid_decimal_precision(val, target_precision));
193+
Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
194+
new_v,
195+
target_precision,
196+
target_scale,
197+
)))
198+
}
147199
}
148200
v => Err(DataFusionError::Execution(format!(
149201
"CheckOverflow's child expression should be decimal array, but found {v:?}"

0 commit comments

Comments
 (0)