Skip to content

Commit d7495bd

Browse files
committed
feat: add criterion benchmark for wide decimal binary expr
Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases: add with same scale, add with different scales, multiply, and subtract.
1 parent 4b3fd48 commit d7495bd

2 files changed

Lines changed: 166 additions & 0 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
103103
[[bench]]
104104
name = "cast_from_boolean"
105105
harness = false
106+
107+
[[bench]]
108+
name = "wide_decimal"
109+
harness = false
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
19+
//! for Decimal128 arithmetic that requires wider intermediate precision.
20+
21+
use arrow::array::builder::Decimal128Builder;
22+
use arrow::array::RecordBatch;
23+
use arrow::datatypes::{DataType, Field, Schema};
24+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
25+
use datafusion::logical_expr::Operator;
26+
use datafusion::physical_expr::expressions::{BinaryExpr, Column};
27+
use datafusion::physical_expr::PhysicalExpr;
28+
use datafusion_comet_spark_expr::{
29+
Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
30+
};
31+
use std::sync::Arc;
32+
33+
/// Number of rows appended to each Decimal128 column of a benchmark batch.
const BATCH_SIZE: usize = 8192;
34+
35+
/// Build a RecordBatch with two Decimal128 columns.
36+
fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
37+
let mut left = Decimal128Builder::new();
38+
let mut right = Decimal128Builder::new();
39+
for i in 0..BATCH_SIZE as i128 {
40+
left.append_value(123456789012345_i128 + i * 1000);
41+
right.append_value(987654321098765_i128 - i * 1000);
42+
}
43+
let left = left.finish().with_data_type(DataType::Decimal128(p1, s1));
44+
let right = right.finish().with_data_type(DataType::Decimal128(p2, s2));
45+
let schema = Schema::new(vec![
46+
Field::new("left", DataType::Decimal128(p1, s1), false),
47+
Field::new("right", DataType::Decimal128(p2, s2), false),
48+
]);
49+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap()
50+
}
51+
52+
/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128).
53+
fn build_old_expr(
54+
p1: u8,
55+
s1: i8,
56+
p2: u8,
57+
s2: i8,
58+
op: Operator,
59+
out_type: DataType,
60+
) -> Arc<dyn PhysicalExpr> {
61+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
62+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
63+
let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
64+
let left_cast = Arc::new(Cast::new(
65+
left_col,
66+
DataType::Decimal256(p1, s1),
67+
cast_opts.clone(),
68+
));
69+
let right_cast = Arc::new(Cast::new(
70+
right_col,
71+
DataType::Decimal256(p2, s2),
72+
cast_opts.clone(),
73+
));
74+
let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
75+
Arc::new(Cast::new(binary, out_type, cast_opts))
76+
}
77+
78+
/// New approach: single fused WideDecimalBinaryExpr.
79+
fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
80+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
81+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
82+
Arc::new(WideDecimalBinaryExpr::new(
83+
left_col,
84+
right_col,
85+
op,
86+
p_out,
87+
s_out,
88+
EvalMode::Legacy,
89+
))
90+
}
91+
92+
/// Registers two benchmarks in `group` for one scenario: `old/<name>` evaluates `old_expr`
/// and `fused/<name>` evaluates `new_expr`, both against the same `batch`.
///
/// Each measured iteration is a single `evaluate` call; an evaluation error panics.
fn bench_case(
    group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
    name: &str,
    batch: &RecordBatch,
    old_expr: &Arc<dyn PhysicalExpr>,
    new_expr: &Arc<dyn PhysicalExpr>,
) {
    // "old" is registered first, then "fused", matching the order results are reported in.
    for (label, expr) in [("old", old_expr), ("fused", new_expr)] {
        group.bench_with_input(BenchmarkId::new(label, name), batch, |bencher, input| {
            bencher.iter(|| expr.evaluate(input).unwrap());
        });
    }
}
106+
107+
fn criterion_benchmark(c: &mut Criterion) {
108+
let mut group = c.benchmark_group("wide_decimal");
109+
110+
// Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10)
111+
// Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
112+
{
113+
let batch = make_decimal_batch(38, 10, 38, 10);
114+
let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10));
115+
let new = build_new_expr(WideDecimalOp::Add, 38, 10);
116+
bench_case(&mut group, "add_same_scale", &batch, &old, &new);
117+
}
118+
119+
// Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
120+
{
121+
let batch = make_decimal_batch(38, 6, 38, 4);
122+
let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6));
123+
let new = build_new_expr(WideDecimalOp::Add, 38, 6);
124+
bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
125+
}
126+
127+
// Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6)
128+
// Triggers wide path because p1 + p2 = 40 >= 38
129+
{
130+
let batch = make_decimal_batch(20, 10, 20, 10);
131+
let old = build_old_expr(
132+
20,
133+
10,
134+
20,
135+
10,
136+
Operator::Multiply,
137+
DataType::Decimal128(38, 6),
138+
);
139+
let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
140+
bench_case(&mut group, "multiply", &batch, &old, &new);
141+
}
142+
143+
// Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
144+
{
145+
let batch = make_decimal_batch(38, 18, 38, 18);
146+
let old = build_old_expr(
147+
38,
148+
18,
149+
38,
150+
18,
151+
Operator::Minus,
152+
DataType::Decimal128(38, 18),
153+
);
154+
let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
155+
bench_case(&mut group, "subtract", &batch, &old, &new);
156+
}
157+
158+
group.finish();
159+
}
160+
161+
// Collect `criterion_benchmark` into the `benches` group.
criterion_group!(benches, criterion_benchmark);
162+
// Entry point for the bench target (declared with `harness = false`); runs the `benches` group.
criterion_main!(benches);

0 commit comments

Comments
 (0)