Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 177 additions & 5 deletions native/spark-expr/benches/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder};

use arrow::array::builder::{Decimal128Builder, StringBuilder};
use arrow::array::{ArrayRef, RecordBatch};
use arrow::array::builder::{Decimal128Builder, Int64Builder, StringBuilder};
use arrow::array::{ArrayRef, Int64Array, RecordBatch};
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, Criterion};
Expand All @@ -25,14 +25,14 @@ use datafusion::datasource::source::DataSourceExec;
use datafusion::execution::TaskContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::logical_expr::AggregateUDF;
use datafusion::logical_expr::function::AccumulatorArgs;
use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, EmitTo};
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_comet_spark_expr::SumDecimal;
use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
use datafusion_comet_spark_expr::{AvgDecimal, EvalMode, SumDecimal, SumInteger};
use futures::StreamExt;
use std::hint::black_box;
use std::sync::Arc;
Expand Down Expand Up @@ -111,6 +111,153 @@ fn criterion_benchmark(c: &mut Criterion) {
});

group.finish();

// SumInteger benchmarks
let mut group = c.benchmark_group("sum_integer");
let int_batch = create_int64_record_batch(num_rows);
let mut int_batches = Vec::new();
for _ in 0..10 {
int_batches.push(int_batch.clone());
}
let int_partitions = &[int_batches];
let int_c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
let int_c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));

group.bench_function("sum_int64_datafusion", |b| {
let datafusion_sum = sum_udaf();
b.to_async(&rt).iter(|| {
black_box(agg_test(
int_partitions,
int_c0.clone(),
int_c1.clone(),
datafusion_sum.clone(),
"sum",
))
})
});

group.bench_function("sum_int64_comet_legacy", |b| {
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
SumInteger::try_new(DataType::Int64, EvalMode::Legacy).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
int_partitions,
int_c0.clone(),
int_c1.clone(),
comet_sum.clone(),
"sum",
))
})
});

group.bench_function("sum_int64_comet_ansi", |b| {
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
SumInteger::try_new(DataType::Int64, EvalMode::Ansi).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
int_partitions,
int_c0.clone(),
int_c1.clone(),
comet_sum.clone(),
"sum",
))
})
});

group.bench_function("sum_int64_comet_try", |b| {
let comet_sum = Arc::new(AggregateUDF::new_from_impl(
SumInteger::try_new(DataType::Int64, EvalMode::Try).unwrap(),
));
b.to_async(&rt).iter(|| {
black_box(agg_test(
int_partitions,
int_c0.clone(),
int_c1.clone(),
comet_sum.clone(),
"sum",
))
})
});

group.finish();

// Direct accumulator benchmarks (bypassing execution framework)
let mut group = c.benchmark_group("sum_integer_accumulator");
let int64_array: ArrayRef = Arc::new(Int64Array::from_iter_values(0..8192i64));
let arrays: Vec<ArrayRef> = vec![int64_array];

let return_field = Arc::new(Field::new("sum", DataType::Int64, true));
let schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]);
let expr_field = Arc::new(Field::new("c0", DataType::Int64, true));
let expr_fields: Vec<Arc<Field>> = vec![expr_field];

// Single-row Accumulator benchmarks
for (name, eval_mode) in [
("row_legacy", EvalMode::Legacy),
("row_ansi", EvalMode::Ansi),
("row_try", EvalMode::Try),
] {
let return_field = return_field.clone();
let expr_fields = expr_fields.clone();
group.bench_function(name, |b| {
let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
b.iter(|| {
let acc_args = AccumulatorArgs {
return_field: return_field.clone(),
schema: &schema,
ignore_nulls: false,
order_bys: &[],
name: "sum",
is_distinct: false,
is_reversed: false,
exprs: &[],
expr_fields: &expr_fields,
};
let mut acc = udf.accumulator(acc_args).unwrap();
for _ in 0..10 {
acc.update_batch(&arrays).unwrap();
}
black_box(acc.evaluate().unwrap())
})
});
}

// GroupsAccumulator benchmarks
let group_indices: Vec<usize> = (0..8192).map(|i| i % 1024).collect();
for (name, eval_mode) in [
("groups_legacy", EvalMode::Legacy),
("groups_ansi", EvalMode::Ansi),
("groups_try", EvalMode::Try),
] {
let return_field = return_field.clone();
let expr_fields = expr_fields.clone();
group.bench_function(name, |b| {
let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
b.iter(|| {
let acc_args = AccumulatorArgs {
return_field: return_field.clone(),
schema: &schema,
ignore_nulls: false,
order_bys: &[],
name: "sum",
is_distinct: false,
is_reversed: false,
exprs: &[],
expr_fields: &expr_fields,
};
let mut acc = udf.create_groups_accumulator(acc_args).unwrap();
for _ in 0..10 {
acc.update_batch(&arrays, &group_indices, None, 1024)
.unwrap();
}
black_box(acc.evaluate(EmitTo::All).unwrap())
})
});
}

group.finish();
}

async fn agg_test(
Expand Down Expand Up @@ -187,6 +334,31 @@ fn create_record_batch(num_rows: usize) -> RecordBatch {
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}

fn create_int64_record_batch(num_rows: usize) -> RecordBatch {
let mut int64_builder = Int64Builder::with_capacity(num_rows);
let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
for i in 0..num_rows {
int64_builder.append_value(i as i64);
string_builder.append_value(format!("group_{}", i % 1024));
}
let int64_array = Arc::new(int64_builder.finish());
let string_array = Arc::new(string_builder.finish());

let mut fields = vec![];
let mut columns: Vec<ArrayRef> = vec![];

// string column for grouping
fields.push(Field::new("c0", DataType::Utf8, false));
columns.push(string_array);

// int64 column for summing
fields.push(Field::new("c1", DataType::Int64, false));
columns.push(int64_array);

let schema = Schema::new(fields);
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}

fn config() -> Criterion {
Criterion::default()
.measurement_time(Duration::from_millis(500))
Expand Down
Loading
Loading