Skip to content

Commit 38fa07a

Browse files
Subham SinghalSubham Singhal
authored and committed
Benchmark window topn optimisation
1 parent 5ba06ac commit 38fa07a

11 files changed

Lines changed: 1500 additions & 1 deletion

File tree

datafusion/common/src/config.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,6 +1087,12 @@ config_namespace! {
10871087
/// past window functions, if possible
10881088
pub enable_window_limits: bool, default = true
10891089

1090+
/// When set to true, the optimizer will replace
1091+
/// Filter(rn<=K) → Window(ROW_NUMBER) → Sort patterns with a
1092+
/// PartitionedTopKExec that maintains per-partition heaps, avoiding
1093+
/// a full sort of the input.
1094+
pub enable_window_topn: bool, default = true
1095+
10901096
/// When set to true, the optimizer will push TopK (Sort with fetch)
10911097
/// below hash repartition when the partition key is a prefix of the
10921098
/// sort key, reducing data volume before the shuffle.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Standalone H2O groupby Q8 benchmark: PartitionedTopKExec enabled vs disabled
2+
//
3+
// Usage:
4+
// cargo run --release --example h2o_window_topn_bench
5+
//
6+
// Generates 10M rows in-memory (matching H2O SMALL), then runs Q8
7+
// (ROW_NUMBER top-2 per partition) with the optimization on and off.
8+
9+
use datafusion::prelude::*;
10+
use datafusion_common::instant::Instant;
11+
use std::sync::Arc;
12+
13+
use arrow::array::{Int64Array, Float64Array, RecordBatch};
14+
use arrow::datatypes::{DataType, Field, Schema};
15+
use datafusion::datasource::MemTable;
16+
use rand::rngs::StdRng;
17+
use rand::{Rng, SeedableRng};
18+
19+
/// Total number of rows generated for the benchmark table.
const NUM_ROWS: usize = 10_000_000; // 10M rows (H2O SMALL)
20+
/// Maximum number of rows placed in each generated `RecordBatch`.
const BATCH_SIZE: usize = 100_000;
21+
/// Number of timed runs per configuration (the warm-up run is not counted).
const ITERATIONS: usize = 3;
22+
23+
fn generate_data(num_partitions: i64) -> Arc<MemTable> {
24+
let schema = Arc::new(Schema::new(vec![
25+
Field::new("id6", DataType::Int64, false),
26+
Field::new("v3", DataType::Float64, true),
27+
]));
28+
29+
let mut rng = StdRng::seed_from_u64(42);
30+
let mut batches = Vec::new();
31+
let mut remaining = NUM_ROWS;
32+
33+
while remaining > 0 {
34+
let batch_len = remaining.min(BATCH_SIZE);
35+
remaining -= batch_len;
36+
37+
let id6: Int64Array = (0..batch_len)
38+
.map(|_| rng.random_range(0..num_partitions))
39+
.collect();
40+
41+
let v3: Float64Array = (0..batch_len)
42+
.map(|_| {
43+
if rng.random_range(0..100) < 5 {
44+
None // 5% nulls
45+
} else {
46+
Some(rng.random_range(0.0..1000.0))
47+
}
48+
})
49+
.collect();
50+
51+
batches.push(
52+
RecordBatch::try_new(Arc::clone(&schema), vec![
53+
Arc::new(id6),
54+
Arc::new(v3),
55+
])
56+
.unwrap(),
57+
);
58+
}
59+
60+
// Split into 8 partitions
61+
let partition_size = batches.len() / 8;
62+
let mut partitions: Vec<Vec<RecordBatch>> = Vec::new();
63+
for chunk in batches.chunks(partition_size.max(1)) {
64+
partitions.push(chunk.to_vec());
65+
}
66+
67+
Arc::new(MemTable::try_new(schema, partitions).unwrap())
68+
}
69+
70+
/// H2O groupby Q8: the two largest non-null `v3` values per `id6`, expressed as
/// a `ROW_NUMBER()` window filtered with `order_v3 <= 2`.
///
/// The trailing `\` on each line joins the pieces into a single-line SQL
/// string; leading whitespace on the continuation lines is not part of it.
const Q8: &str = "\
    SELECT id6, largest2_v3 FROM (\
    SELECT id6, v3 AS largest2_v3, \
    ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 \
    FROM x WHERE v3 IS NOT NULL\
    ) sub_query WHERE order_v3 <= 2";
76+
77+
#[tokio::main]
78+
async fn main() {
79+
// Test across different partition cardinalities
80+
let scenarios = [
81+
(100, "100 partitions (100K rows/partition)"),
82+
(1_000, "1K partitions (10K rows/partition)"),
83+
(10_000, "10K partitions (1K rows/partition)"),
84+
(100_000, "100K partitions (100 rows/partition, H2O-like)"),
85+
];
86+
87+
for (num_partitions, label) in scenarios {
88+
println!("=== Scenario: {label} ===");
89+
println!("Generating {NUM_ROWS} rows with {num_partitions} partitions...");
90+
let table = generate_data(num_partitions);
91+
92+
for (tag, enabled) in [("ENABLED ", true), ("DISABLED", false)] {
93+
let mut config = SessionConfig::new();
94+
config.options_mut().optimizer.enable_window_topn = enabled;
95+
let ctx = SessionContext::new_with_config(config);
96+
ctx.register_table("x", Arc::clone(&table) as _).unwrap();
97+
98+
// Warmup
99+
let df = ctx.sql(Q8).await.unwrap();
100+
let _ = df.collect().await.unwrap();
101+
102+
// Benchmark
103+
let mut times = Vec::new();
104+
for _ in 0..ITERATIONS {
105+
let start = Instant::now();
106+
let df = ctx.sql(Q8).await.unwrap();
107+
let batches = df.collect().await.unwrap();
108+
let elapsed = start.elapsed();
109+
let _row_count: usize = batches.iter().map(|b| b.num_rows()).sum();
110+
times.push(elapsed.as_millis());
111+
}
112+
113+
let avg = times.iter().sum::<u128>() / times.len() as u128;
114+
let min = *times.iter().min().unwrap();
115+
println!(" [{tag}] avg={avg} ms, min={min} ms");
116+
}
117+
println!();
118+
}
119+
}

datafusion/core/tests/physical_optimizer/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,6 @@ mod sanity_checker;
3737
#[expect(clippy::needless_pass_by_value)]
3838
mod test_utils;
3939
mod window_optimize;
40+
mod window_topn;
4041

4142
mod pushdown_utils;

0 commit comments

Comments
 (0)