Skip to content

Commit 0aab6a3

Browse files
authored
feat: support SELECT DISTINCT id FROM t ORDER BY id LIMIT n query use GroupedTopKAggregateStream (#19653)
## Which issue does this PR close? - Closes #19638 ## Rationale for this change See issue #19638. ## What changes are included in this PR? 1. Introduced a `LimitOptions` struct for the limit field, combining `limit` with an optional `descending` ordering direction 2. Extended the `TopKAggregation` optimizer rule to DISTINCT queries by recognizing `GROUP BY` queries without aggregates and setting the `descending` flag based on the ordering direction 3. Enhanced `GroupedTopKAggregateStream` to handle DISTINCT by using the group key as both the priority-queue key and value for DISTINCT operations 4. Updated the Proto definitions to add an optional `descending` field to the `AggLimit` message for serialization/deserialization ## benchmark result <img width="731" height="475" alt="image" src="https://github.com/user-attachments/assets/05b6eb8c-186d-4b17-84a9-a2897dbcb095" /> ## Are these changes tested? Yes; a test case was added in aggregates_topk.slt ## Are there any user-facing changes? No
1 parent 3ea21aa commit 0aab6a3

File tree

16 files changed

+489
-66
lines changed

16 files changed

+489
-66
lines changed

datafusion/core/benches/topk_aggregate.rs

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,63 @@
1717

1818
mod data_utils;
1919

20+
use arrow::array::Int64Builder;
21+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
22+
use arrow::record_batch::RecordBatch;
2023
use arrow::util::pretty::pretty_format_batches;
2124
use criterion::{Criterion, criterion_group, criterion_main};
2225
use data_utils::make_data;
2326
use datafusion::physical_plan::{collect, displayable};
2427
use datafusion::prelude::SessionContext;
2528
use datafusion::{datasource::MemTable, error::Result};
2629
use datafusion_execution::config::SessionConfig;
30+
use rand::SeedableRng;
31+
use rand::seq::SliceRandom;
2732
use std::hint::black_box;
2833
use std::sync::Arc;
2934
use tokio::runtime::Runtime;
3035

3136
const LIMIT: usize = 10;
3237

38+
/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids
39+
/// This ensures consistent results across benchmark runs
40+
fn make_distinct_data(
41+
partition_cnt: i32,
42+
sample_cnt: i32,
43+
) -> Result<(Arc<Schema>, Vec<Vec<RecordBatch>>)> {
44+
let mut rng = rand::rngs::SmallRng::from_seed([42; 32]);
45+
let total_samples = partition_cnt as usize * sample_cnt as usize;
46+
let mut ids = Vec::new();
47+
for i in 0..total_samples {
48+
ids.push(i as i64);
49+
}
50+
ids.shuffle(&mut rng);
51+
52+
let mut global_idx = 0;
53+
let schema = test_distinct_schema();
54+
let mut partitions = vec![];
55+
for _ in 0..partition_cnt {
56+
let mut id_builder = Int64Builder::new();
57+
58+
for _ in 0..sample_cnt {
59+
let id = ids[global_idx];
60+
id_builder.append_value(id);
61+
global_idx += 1;
62+
}
63+
64+
let id_col = Arc::new(id_builder.finish());
65+
let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?;
66+
partitions.push(vec![batch]);
67+
}
68+
69+
Ok((schema, partitions))
70+
}
71+
72+
/// Returns a Schema for distinct benchmarks with i64 trace_id
73+
fn test_distinct_schema() -> SchemaRef {
74+
Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
75+
}
76+
3377
async fn create_context(
3478
partition_cnt: i32,
3579
sample_cnt: i32,
@@ -50,6 +94,25 @@ async fn create_context(
5094
Ok(ctx)
5195
}
5296

97+
async fn create_context_distinct(
98+
partition_cnt: i32,
99+
sample_cnt: i32,
100+
use_topk: bool,
101+
) -> Result<SessionContext> {
102+
// Use deterministic data generation for DISTINCT queries to ensure consistent results
103+
let (schema, parts) = make_distinct_data(partition_cnt, sample_cnt).unwrap();
104+
let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap());
105+
106+
// Create the DataFrame
107+
let mut cfg = SessionConfig::new();
108+
let opts = cfg.options_mut();
109+
opts.optimizer.enable_topk_aggregation = use_topk;
110+
let ctx = SessionContext::new_with_config(cfg);
111+
let _ = ctx.register_table("traces", mem_table)?;
112+
113+
Ok(ctx)
114+
}
115+
53116
fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) {
54117
black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap();
55118
}
@@ -59,6 +122,17 @@ fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) {
59122
.unwrap();
60123
}
61124

125+
fn run_distinct(
126+
rt: &Runtime,
127+
ctx: SessionContext,
128+
limit: usize,
129+
use_topk: bool,
130+
asc: bool,
131+
) {
132+
black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await }))
133+
.unwrap();
134+
}
135+
62136
async fn aggregate(
63137
ctx: SessionContext,
64138
limit: usize,
@@ -133,6 +207,84 @@ async fn aggregate_string(
133207
Ok(())
134208
}
135209

210+
async fn aggregate_distinct(
211+
ctx: SessionContext,
212+
limit: usize,
213+
use_topk: bool,
214+
asc: bool,
215+
) -> Result<()> {
216+
let order_direction = if asc { "asc" } else { "desc" };
217+
let sql = format!(
218+
"select id from traces group by id order by id {order_direction} limit {limit};"
219+
);
220+
let df = ctx.sql(sql.as_str()).await?;
221+
let plan = df.create_physical_plan().await?;
222+
let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
223+
assert_eq!(
224+
actual_phys_plan.contains(&format!("lim=[{limit}]")),
225+
use_topk
226+
);
227+
let batches = collect(plan, ctx.task_ctx()).await?;
228+
assert_eq!(batches.len(), 1);
229+
let batch = batches.first().unwrap();
230+
assert_eq!(batch.num_rows(), LIMIT);
231+
232+
let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase();
233+
234+
let expected_asc = r#"
235+
+----+
236+
| id |
237+
+----+
238+
| 0 |
239+
| 1 |
240+
| 2 |
241+
| 3 |
242+
| 4 |
243+
| 5 |
244+
| 6 |
245+
| 7 |
246+
| 8 |
247+
| 9 |
248+
+----+
249+
"#
250+
.trim();
251+
252+
let expected_desc = r#"
253+
+---------+
254+
| id |
255+
+---------+
256+
| 9999999 |
257+
| 9999998 |
258+
| 9999997 |
259+
| 9999996 |
260+
| 9999995 |
261+
| 9999994 |
262+
| 9999993 |
263+
| 9999992 |
264+
| 9999991 |
265+
| 9999990 |
266+
+---------+
267+
"#
268+
.trim();
269+
270+
// Verify exact results match expected values
271+
if asc {
272+
assert_eq!(
273+
actual.trim(),
274+
expected_asc,
275+
"Ascending DISTINCT results do not match expected values"
276+
);
277+
} else {
278+
assert_eq!(
279+
actual.trim(),
280+
expected_desc,
281+
"Descending DISTINCT results do not match expected values"
282+
);
283+
}
284+
285+
Ok(())
286+
}
287+
136288
fn criterion_benchmark(c: &mut Criterion) {
137289
let rt = Runtime::new().unwrap();
138290
let limit = LIMIT;
@@ -253,6 +405,37 @@ fn criterion_benchmark(c: &mut Criterion) {
253405
.as_str(),
254406
|b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
255407
);
408+
409+
// DISTINCT benchmarks
410+
let ctx = rt.block_on(async {
411+
create_context_distinct(partitions, samples, false)
412+
.await
413+
.unwrap()
414+
});
415+
c.bench_function(
416+
format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(),
417+
|b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, false)),
418+
);
419+
420+
c.bench_function(
421+
format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(),
422+
|b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)),
423+
);
424+
425+
let ctx_topk = rt.block_on(async {
426+
create_context_distinct(partitions, samples, true)
427+
.await
428+
.unwrap()
429+
});
430+
c.bench_function(
431+
format!("distinct {} rows desc [TopK]", partitions * samples).as_str(),
432+
|b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)),
433+
);
434+
435+
c.bench_function(
436+
format!("distinct {} rows asc [TopK]", partitions * samples).as_str(),
437+
|b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)),
438+
);
256439
}
257440

258441
criterion_group!(benches, criterion_benchmark);

datafusion/core/tests/execution/coop.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use datafusion::physical_expr::aggregate::AggregateExprBuilder;
2424
use datafusion::physical_plan;
2525
use datafusion::physical_plan::ExecutionPlan;
2626
use datafusion::physical_plan::aggregates::{
27-
AggregateExec, AggregateMode, PhysicalGroupBy,
27+
AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy,
2828
};
2929
use datafusion::physical_plan::execution_plan::Boundedness;
3030
use datafusion::prelude::SessionContext;
@@ -233,6 +233,7 @@ async fn agg_grouped_topk_yields(
233233
#[values(false, true)] pretend_infinite: bool,
234234
) -> Result<(), Box<dyn Error>> {
235235
// build session
236+
236237
let session_ctx = SessionContext::new();
237238

238239
// set up a top-k aggregation
@@ -260,7 +261,7 @@ async fn agg_grouped_topk_yields(
260261
inf.clone(),
261262
inf.schema(),
262263
)?
263-
.with_limit(Some(100)),
264+
.with_limit_options(Some(LimitOptions::new(100))),
264265
);
265266

266267
query_yields(aggr, session_ctx.task_ctx()).await

datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule;
3737
use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
3838
use datafusion_physical_plan::ExecutionPlan;
3939
use datafusion_physical_plan::aggregates::{
40-
AggregateExec, AggregateMode, PhysicalGroupBy,
40+
AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy,
4141
};
4242
use datafusion_physical_plan::displayable;
4343
use datafusion_physical_plan::repartition::RepartitionExec;
@@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> {
260260
schema,
261261
)
262262
.unwrap()
263-
.with_limit(Some(5)),
263+
.with_limit_options(Some(LimitOptions::new(5))),
264264
);
265265
let plan: Arc<dyn ExecutionPlan> = final_agg;
266266
// should combine the Partial/Final AggregateExecs to a Single AggregateExec

datafusion/physical-optimizer/src/combine_partial_final_agg.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
9898
Arc::clone(input_agg_exec.input()),
9999
input_agg_exec.input_schema(),
100100
)
101-
.map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
101+
.map(|combined_agg| {
102+
combined_agg.with_limit_options(agg_exec.limit_options())
103+
})
102104
.ok()
103105
.map(Arc::new)
104106
} else {

datafusion/physical-optimizer/src/limited_distinct_aggregation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
2121
use std::sync::Arc;
2222

23-
use datafusion_physical_plan::aggregates::AggregateExec;
23+
use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions};
2424
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
2525
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
2626

@@ -63,7 +63,7 @@ impl LimitedDistinctAggregation {
6363
aggr.input_schema(),
6464
)
6565
.expect("Unable to copy Aggregate!")
66-
.with_limit(Some(limit));
66+
.with_limit_options(Some(LimitOptions::new(limit)));
6767
Some(Arc::new(new_aggr))
6868
}
6969

datafusion/physical-optimizer/src/topk_aggregation.rs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
2525
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
2626
use datafusion_physical_expr::expressions::Column;
2727
use datafusion_physical_plan::ExecutionPlan;
28+
use datafusion_physical_plan::aggregates::LimitOptions;
2829
use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported};
2930
use datafusion_physical_plan::execution_plan::CardinalityEffect;
3031
use datafusion_physical_plan::projection::ProjectionExec;
@@ -47,28 +48,47 @@ impl TopKAggregation {
4748
order_desc: bool,
4849
limit: usize,
4950
) -> Option<Arc<dyn ExecutionPlan>> {
50-
// ensure the sort direction matches aggregate function
51-
let (field, desc) = aggr.get_minmax_desc()?;
52-
if desc != order_desc {
53-
return None;
54-
}
55-
let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
56-
let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
57-
let vt = field.data_type();
58-
if !topk_types_supported(&kt, vt) {
51+
// Current only support single group key
52+
let (group_key, group_key_alias) =
53+
aggr.group_expr().expr().iter().exactly_one().ok()?;
54+
let kt = group_key.data_type(&aggr.input().schema()).ok()?;
55+
let vt = if let Some((field, _)) = aggr.get_minmax_desc() {
56+
field.data_type().clone()
57+
} else {
58+
kt.clone()
59+
};
60+
if !topk_types_supported(&kt, &vt) {
5961
return None;
6062
}
6163
if aggr.filter_expr().iter().any(|e| e.is_some()) {
6264
return None;
6365
}
6466

65-
// ensure the sort is on the same field as the aggregate output
66-
if order_by != field.name() {
67+
// Check if this is ordering by an aggregate function (MIN/MAX)
68+
if let Some((field, desc)) = aggr.get_minmax_desc() {
69+
// ensure the sort direction matches aggregate function
70+
if desc != order_desc {
71+
return None;
72+
}
73+
// ensure the sort is on the same field as the aggregate output
74+
if order_by != field.name() {
75+
return None;
76+
}
77+
} else if aggr.aggr_expr().is_empty() {
78+
// This is a GROUP BY without aggregates; check whether the ordering is on the group key itself
79+
if order_by != group_key_alias {
80+
return None;
81+
}
82+
} else {
83+
// Has aggregates that are not MIN/MAX, or is not a DISTINCT-style query
6784
return None;
6885
}
6986

7087
// We found what we want: clone, copy the limit down, and return modified node
71-
let new_aggr = aggr.with_new_limit(Some(limit));
88+
let new_aggr = AggregateExec::with_new_limit_options(
89+
aggr,
90+
Some(LimitOptions::new_with_order(limit, order_desc)),
91+
);
7292
Some(Arc::new(new_aggr))
7393
}
7494

0 commit comments

Comments
 (0)