Skip to content

Commit cb52636

Browse files
committed
feat: add criterion benchmark for wide decimal binary expr
Add benchmark comparing old Cast->BinaryExpr->Cast chain vs fused WideDecimalBinaryExpr for Decimal128 add/sub/mul. Covers four cases: add with same scale, add with different scales, multiply, and subtract.
1 parent fe9b1bd commit cb52636

2 files changed

Lines changed: 151 additions & 0 deletions

File tree

native/spark-expr/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
103103
[[bench]]
104104
name = "cast_from_boolean"
105105
harness = false
106+
107+
[[bench]]
108+
name = "wide_decimal"
109+
harness = false
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
19+
//! for Decimal128 arithmetic that requires wider intermediate precision.
20+
21+
use arrow::array::builder::Decimal128Builder;
22+
use arrow::array::RecordBatch;
23+
use arrow::datatypes::{DataType, Field, Schema};
24+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
25+
use datafusion::logical_expr::Operator;
26+
use datafusion::physical_expr::expressions::{BinaryExpr, Column};
27+
use datafusion::physical_expr::PhysicalExpr;
28+
use datafusion_comet_spark_expr::{
29+
Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
30+
};
31+
use std::sync::Arc;
32+
33+
const BATCH_SIZE: usize = 8192;
34+
35+
/// Build a RecordBatch with two Decimal128 columns.
36+
fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
37+
let mut left = Decimal128Builder::new();
38+
let mut right = Decimal128Builder::new();
39+
for i in 0..BATCH_SIZE as i128 {
40+
left.append_value(123456789012345_i128 + i * 1000);
41+
right.append_value(987654321098765_i128 - i * 1000);
42+
}
43+
let left = left
44+
.finish()
45+
.with_data_type(DataType::Decimal128(p1, s1));
46+
let right = right
47+
.finish()
48+
.with_data_type(DataType::Decimal128(p2, s2));
49+
let schema = Schema::new(vec![
50+
Field::new("left", DataType::Decimal128(p1, s1), false),
51+
Field::new("right", DataType::Decimal128(p2, s2), false),
52+
]);
53+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(left), Arc::new(right)]).unwrap()
54+
}
55+
56+
/// Old approach: Cast(Decimal128->Decimal256) both sides, BinaryExpr, Cast(Decimal256->Decimal128).
57+
fn build_old_expr(
58+
p1: u8,
59+
s1: i8,
60+
p2: u8,
61+
s2: i8,
62+
op: Operator,
63+
out_type: DataType,
64+
) -> Arc<dyn PhysicalExpr> {
65+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
66+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
67+
let cast_opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
68+
let left_cast = Arc::new(Cast::new(
69+
left_col,
70+
DataType::Decimal256(p1, s1),
71+
cast_opts.clone(),
72+
));
73+
let right_cast = Arc::new(Cast::new(
74+
right_col,
75+
DataType::Decimal256(p2, s2),
76+
cast_opts.clone(),
77+
));
78+
let binary = Arc::new(BinaryExpr::new(left_cast, op, right_cast));
79+
Arc::new(Cast::new(binary, out_type, cast_opts))
80+
}
81+
82+
/// New approach: single fused WideDecimalBinaryExpr.
83+
fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
84+
let left_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
85+
let right_col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
86+
Arc::new(WideDecimalBinaryExpr::new(
87+
left_col, right_col, op, p_out, s_out, EvalMode::Legacy,
88+
))
89+
}
90+
91+
fn bench_case(
92+
group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
93+
name: &str,
94+
batch: &RecordBatch,
95+
old_expr: &Arc<dyn PhysicalExpr>,
96+
new_expr: &Arc<dyn PhysicalExpr>,
97+
) {
98+
group.bench_with_input(BenchmarkId::new("old", name), batch, |b, batch| {
99+
b.iter(|| old_expr.evaluate(batch).unwrap());
100+
});
101+
group.bench_with_input(BenchmarkId::new("fused", name), batch, |b, batch| {
102+
b.iter(|| new_expr.evaluate(batch).unwrap());
103+
});
104+
}
105+
106+
fn criterion_benchmark(c: &mut Criterion) {
107+
let mut group = c.benchmark_group("wide_decimal");
108+
109+
// Case 1: Add with same scale - Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10)
110+
// Triggers wide path because max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
111+
{
112+
let batch = make_decimal_batch(38, 10, 38, 10);
113+
let old = build_old_expr(38, 10, 38, 10, Operator::Plus, DataType::Decimal128(38, 10));
114+
let new = build_new_expr(WideDecimalOp::Add, 38, 10);
115+
bench_case(&mut group, "add_same_scale", &batch, &old, &new);
116+
}
117+
118+
// Case 2: Add with different scales - Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
119+
{
120+
let batch = make_decimal_batch(38, 6, 38, 4);
121+
let old = build_old_expr(38, 6, 38, 4, Operator::Plus, DataType::Decimal128(38, 6));
122+
let new = build_new_expr(WideDecimalOp::Add, 38, 6);
123+
bench_case(&mut group, "add_diff_scale", &batch, &old, &new);
124+
}
125+
126+
// Case 3: Multiply - Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6)
127+
// Triggers wide path because p1 + p2 = 40 >= 38
128+
{
129+
let batch = make_decimal_batch(20, 10, 20, 10);
130+
let old = build_old_expr(20, 10, 20, 10, Operator::Multiply, DataType::Decimal128(38, 6));
131+
let new = build_new_expr(WideDecimalOp::Multiply, 38, 6);
132+
bench_case(&mut group, "multiply", &batch, &old, &new);
133+
}
134+
135+
// Case 4: Subtract with same scale - Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
136+
{
137+
let batch = make_decimal_batch(38, 18, 38, 18);
138+
let old = build_old_expr(38, 18, 38, 18, Operator::Minus, DataType::Decimal128(38, 18));
139+
let new = build_new_expr(WideDecimalOp::Subtract, 38, 18);
140+
bench_case(&mut group, "subtract", &batch, &old, &new);
141+
}
142+
143+
group.finish();
144+
}
145+
146+
criterion_group!(benches, criterion_benchmark);
147+
criterion_main!(benches);

0 commit comments

Comments
 (0)