Skip to content

Commit fd47751

Browse files
authored
perf: refactor sum int with specialized implementations for each eval_mode (#3054)
1 parent d2fbd6e commit fd47751

2 files changed

Lines changed: 759 additions & 274 deletions

File tree

native/spark-expr/benches/aggregate.rs

Lines changed: 177 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::builder::{Decimal128Builder, StringBuilder};
19-
use arrow::array::{ArrayRef, RecordBatch};
18+
use arrow::array::builder::{Decimal128Builder, Int64Builder, StringBuilder};
19+
use arrow::array::{ArrayRef, Int64Array, RecordBatch};
2020
use arrow::datatypes::SchemaRef;
2121
use arrow::datatypes::{DataType, Field, Schema};
2222
use criterion::{criterion_group, criterion_main, Criterion};
@@ -25,14 +25,14 @@ use datafusion::datasource::source::DataSourceExec;
2525
use datafusion::execution::TaskContext;
2626
use datafusion::functions_aggregate::average::avg_udaf;
2727
use datafusion::functions_aggregate::sum::sum_udaf;
28-
use datafusion::logical_expr::AggregateUDF;
28+
use datafusion::logical_expr::function::AccumulatorArgs;
29+
use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, EmitTo};
2930
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
3031
use datafusion::physical_expr::expressions::Column;
3132
use datafusion::physical_expr::PhysicalExpr;
3233
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
3334
use datafusion::physical_plan::ExecutionPlan;
34-
use datafusion_comet_spark_expr::SumDecimal;
35-
use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
35+
use datafusion_comet_spark_expr::{AvgDecimal, EvalMode, SumDecimal, SumInteger};
3636
use futures::StreamExt;
3737
use std::hint::black_box;
3838
use std::sync::Arc;
@@ -111,6 +111,153 @@ fn criterion_benchmark(c: &mut Criterion) {
111111
});
112112

113113
group.finish();
114+
115+
// SumInteger benchmarks
116+
let mut group = c.benchmark_group("sum_integer");
117+
let int_batch = create_int64_record_batch(num_rows);
118+
let mut int_batches = Vec::new();
119+
for _ in 0..10 {
120+
int_batches.push(int_batch.clone());
121+
}
122+
let int_partitions = &[int_batches];
123+
let int_c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
124+
let int_c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
125+
126+
group.bench_function("sum_int64_datafusion", |b| {
127+
let datafusion_sum = sum_udaf();
128+
b.to_async(&rt).iter(|| {
129+
black_box(agg_test(
130+
int_partitions,
131+
int_c0.clone(),
132+
int_c1.clone(),
133+
datafusion_sum.clone(),
134+
"sum",
135+
))
136+
})
137+
});
138+
139+
group.bench_function("sum_int64_comet_legacy", |b| {
140+
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
141+
SumInteger::try_new(DataType::Int64, EvalMode::Legacy).unwrap(),
142+
));
143+
b.to_async(&rt).iter(|| {
144+
black_box(agg_test(
145+
int_partitions,
146+
int_c0.clone(),
147+
int_c1.clone(),
148+
comet_sum.clone(),
149+
"sum",
150+
))
151+
})
152+
});
153+
154+
group.bench_function("sum_int64_comet_ansi", |b| {
155+
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
156+
SumInteger::try_new(DataType::Int64, EvalMode::Ansi).unwrap(),
157+
));
158+
b.to_async(&rt).iter(|| {
159+
black_box(agg_test(
160+
int_partitions,
161+
int_c0.clone(),
162+
int_c1.clone(),
163+
comet_sum.clone(),
164+
"sum",
165+
))
166+
})
167+
});
168+
169+
group.bench_function("sum_int64_comet_try", |b| {
170+
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
171+
SumInteger::try_new(DataType::Int64, EvalMode::Try).unwrap(),
172+
));
173+
b.to_async(&rt).iter(|| {
174+
black_box(agg_test(
175+
int_partitions,
176+
int_c0.clone(),
177+
int_c1.clone(),
178+
comet_sum.clone(),
179+
"sum",
180+
))
181+
})
182+
});
183+
184+
group.finish();
185+
186+
// Direct accumulator benchmarks (bypassing execution framework)
187+
let mut group = c.benchmark_group("sum_integer_accumulator");
188+
let int64_array: ArrayRef = Arc::new(Int64Array::from_iter_values(0..8192i64));
189+
let arrays: Vec<ArrayRef> = vec![int64_array];
190+
191+
let return_field = Arc::new(Field::new("sum", DataType::Int64, true));
192+
let schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]);
193+
let expr_field = Arc::new(Field::new("c0", DataType::Int64, true));
194+
let expr_fields: Vec<Arc<Field>> = vec![expr_field];
195+
196+
// Single-row Accumulator benchmarks
197+
for (name, eval_mode) in [
198+
("row_legacy", EvalMode::Legacy),
199+
("row_ansi", EvalMode::Ansi),
200+
("row_try", EvalMode::Try),
201+
] {
202+
let return_field = return_field.clone();
203+
let expr_fields = expr_fields.clone();
204+
group.bench_function(name, |b| {
205+
let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
206+
b.iter(|| {
207+
let acc_args = AccumulatorArgs {
208+
return_field: return_field.clone(),
209+
schema: &schema,
210+
ignore_nulls: false,
211+
order_bys: &[],
212+
name: "sum",
213+
is_distinct: false,
214+
is_reversed: false,
215+
exprs: &[],
216+
expr_fields: &expr_fields,
217+
};
218+
let mut acc = udf.accumulator(acc_args).unwrap();
219+
for _ in 0..10 {
220+
acc.update_batch(&arrays).unwrap();
221+
}
222+
black_box(acc.evaluate().unwrap())
223+
})
224+
});
225+
}
226+
227+
// GroupsAccumulator benchmarks
228+
let group_indices: Vec<usize> = (0..8192).map(|i| i % 1024).collect();
229+
for (name, eval_mode) in [
230+
("groups_legacy", EvalMode::Legacy),
231+
("groups_ansi", EvalMode::Ansi),
232+
("groups_try", EvalMode::Try),
233+
] {
234+
let return_field = return_field.clone();
235+
let expr_fields = expr_fields.clone();
236+
group.bench_function(name, |b| {
237+
let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
238+
b.iter(|| {
239+
let acc_args = AccumulatorArgs {
240+
return_field: return_field.clone(),
241+
schema: &schema,
242+
ignore_nulls: false,
243+
order_bys: &[],
244+
name: "sum",
245+
is_distinct: false,
246+
is_reversed: false,
247+
exprs: &[],
248+
expr_fields: &expr_fields,
249+
};
250+
let mut acc = udf.create_groups_accumulator(acc_args).unwrap();
251+
for _ in 0..10 {
252+
acc.update_batch(&arrays, &group_indices, None, 1024)
253+
.unwrap();
254+
}
255+
black_box(acc.evaluate(EmitTo::All).unwrap())
256+
})
257+
});
258+
}
259+
260+
group.finish();
114261
}
115262

116263
async fn agg_test(
@@ -187,6 +334,31 @@ fn create_record_batch(num_rows: usize) -> RecordBatch {
187334
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
188335
}
189336

337+
fn create_int64_record_batch(num_rows: usize) -> RecordBatch {
338+
let mut int64_builder = Int64Builder::with_capacity(num_rows);
339+
let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
340+
for i in 0..num_rows {
341+
int64_builder.append_value(i as i64);
342+
string_builder.append_value(format!("group_{}", i % 1024));
343+
}
344+
let int64_array = Arc::new(int64_builder.finish());
345+
let string_array = Arc::new(string_builder.finish());
346+
347+
let mut fields = vec![];
348+
let mut columns: Vec<ArrayRef> = vec![];
349+
350+
// string column for grouping
351+
fields.push(Field::new("c0", DataType::Utf8, false));
352+
columns.push(string_array);
353+
354+
// int64 column for summing
355+
fields.push(Field::new("c1", DataType::Int64, false));
356+
columns.push(int64_array);
357+
358+
let schema = Schema::new(fields);
359+
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
360+
}
361+
190362
fn config() -> Criterion {
191363
Criterion::default()
192364
.measurement_time(Duration::from_millis(500))

0 commit comments

Comments
 (0)