1+ // Standalone H2O groupby Q8 benchmark: PartitionedTopKExec enabled vs disabled
2+ //
3+ // Usage:
4+ // cargo run --release --example h2o_window_topn_bench
5+ //
6+ // Generates 10M rows in-memory (matching H2O SMALL), then runs Q8
7+ // (ROW_NUMBER top-2 per partition) with the optimization on and off.
8+
9+ use datafusion:: prelude:: * ;
10+ use datafusion_common:: instant:: Instant ;
11+ use std:: sync:: Arc ;
12+
13+ use arrow:: array:: { Int64Array , Float64Array , RecordBatch } ;
14+ use arrow:: datatypes:: { DataType , Field , Schema } ;
15+ use datafusion:: datasource:: MemTable ;
16+ use rand:: rngs:: StdRng ;
17+ use rand:: { Rng , SeedableRng } ;
18+
/// Total number of rows generated for each scenario (10M, the H2O SMALL size).
const NUM_ROWS: usize = 10_000_000;

/// Maximum number of rows placed in a single generated `RecordBatch`.
const BATCH_SIZE: usize = 100_000;

/// Number of timed runs per configuration (the warmup run is not counted).
const ITERATIONS: usize = 3;
22+
23+ fn generate_data ( num_partitions : i64 ) -> Arc < MemTable > {
24+ let schema = Arc :: new ( Schema :: new ( vec ! [
25+ Field :: new( "id6" , DataType :: Int64 , false ) ,
26+ Field :: new( "v3" , DataType :: Float64 , true ) ,
27+ ] ) ) ;
28+
29+ let mut rng = StdRng :: seed_from_u64 ( 42 ) ;
30+ let mut batches = Vec :: new ( ) ;
31+ let mut remaining = NUM_ROWS ;
32+
33+ while remaining > 0 {
34+ let batch_len = remaining. min ( BATCH_SIZE ) ;
35+ remaining -= batch_len;
36+
37+ let id6: Int64Array = ( 0 ..batch_len)
38+ . map ( |_| rng. random_range ( 0 ..num_partitions) )
39+ . collect ( ) ;
40+
41+ let v3: Float64Array = ( 0 ..batch_len)
42+ . map ( |_| {
43+ if rng. random_range ( 0 ..100 ) < 5 {
44+ None // 5% nulls
45+ } else {
46+ Some ( rng. random_range ( 0.0 ..1000.0 ) )
47+ }
48+ } )
49+ . collect ( ) ;
50+
51+ batches. push (
52+ RecordBatch :: try_new ( Arc :: clone ( & schema) , vec ! [
53+ Arc :: new( id6) ,
54+ Arc :: new( v3) ,
55+ ] )
56+ . unwrap ( ) ,
57+ ) ;
58+ }
59+
60+ // Split into 8 partitions
61+ let partition_size = batches. len ( ) / 8 ;
62+ let mut partitions: Vec < Vec < RecordBatch > > = Vec :: new ( ) ;
63+ for chunk in batches. chunks ( partition_size. max ( 1 ) ) {
64+ partitions. push ( chunk. to_vec ( ) ) ;
65+ }
66+
67+ Arc :: new ( MemTable :: try_new ( schema, partitions) . unwrap ( ) )
68+ }
69+
/// H2O groupby Q8: for each `id6`, the two largest non-null `v3` values,
/// selected by filtering `ROW_NUMBER()` over a per-`id6` descending window.
const Q8: &str = concat!(
    "SELECT id6, largest2_v3 FROM (",
    "SELECT id6, v3 AS largest2_v3, ",
    "ROW_NUMBER() OVER (PARTITION BY id6 ORDER BY v3 DESC) AS order_v3 ",
    "FROM x WHERE v3 IS NOT NULL",
    ") sub_query WHERE order_v3 <= 2",
);
76+
77+ #[ tokio:: main]
78+ async fn main ( ) {
79+ // Test across different partition cardinalities
80+ let scenarios = [
81+ ( 100 , "100 partitions (100K rows/partition)" ) ,
82+ ( 1_000 , "1K partitions (10K rows/partition)" ) ,
83+ ( 10_000 , "10K partitions (1K rows/partition)" ) ,
84+ ( 100_000 , "100K partitions (100 rows/partition, H2O-like)" ) ,
85+ ] ;
86+
87+ for ( num_partitions, label) in scenarios {
88+ println ! ( "=== Scenario: {label} ===" ) ;
89+ println ! ( "Generating {NUM_ROWS} rows with {num_partitions} partitions..." ) ;
90+ let table = generate_data ( num_partitions) ;
91+
92+ for ( tag, enabled) in [ ( "ENABLED " , true ) , ( "DISABLED" , false ) ] {
93+ let mut config = SessionConfig :: new ( ) ;
94+ config. options_mut ( ) . optimizer . enable_window_topn = enabled;
95+ let ctx = SessionContext :: new_with_config ( config) ;
96+ ctx. register_table ( "x" , Arc :: clone ( & table) as _ ) . unwrap ( ) ;
97+
98+ // Warmup
99+ let df = ctx. sql ( Q8 ) . await . unwrap ( ) ;
100+ let _ = df. collect ( ) . await . unwrap ( ) ;
101+
102+ // Benchmark
103+ let mut times = Vec :: new ( ) ;
104+ for _ in 0 ..ITERATIONS {
105+ let start = Instant :: now ( ) ;
106+ let df = ctx. sql ( Q8 ) . await . unwrap ( ) ;
107+ let batches = df. collect ( ) . await . unwrap ( ) ;
108+ let elapsed = start. elapsed ( ) ;
109+ let _row_count: usize = batches. iter ( ) . map ( |b| b. num_rows ( ) ) . sum ( ) ;
110+ times. push ( elapsed. as_millis ( ) ) ;
111+ }
112+
113+ let avg = times. iter ( ) . sum :: < u128 > ( ) / times. len ( ) as u128 ;
114+ let min = * times. iter ( ) . min ( ) . unwrap ( ) ;
115+ println ! ( " [{tag}] avg={avg} ms, min={min} ms" ) ;
116+ }
117+ println ! ( ) ;
118+ }
119+ }
0 commit comments