Skip to content

Commit 1ab7e41

Browse files
Fix grouping set subset satisfaction (#19853)
## Summary

- Fixes incorrect results from ROLLUP/CUBE/GROUPING SETS queries when using multiple partitions
- The subset satisfaction optimization was incorrectly allowing hash partitioning on fewer columns to satisfy requirements that include `__grouping_id`
- This caused partial aggregates from different partitions to be finalized independently, producing duplicate grand totals

Closes #19849
1 parent e353eb0 commit 1ab7e41

File tree

2 files changed

+263
-2
lines changed

2 files changed

+263
-2
lines changed

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ use datafusion_common::config::ConfigOptions;
3636
use datafusion_common::error::Result;
3737
use datafusion_common::stats::Precision;
3838
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
39-
use datafusion_expr::logical_plan::JoinType;
39+
use datafusion_expr::logical_plan::{Aggregate, JoinType};
4040
use datafusion_physical_expr::expressions::{Column, NoOp};
4141
use datafusion_physical_expr::utils::map_columns_before_projection;
4242
use datafusion_physical_expr::{
@@ -1301,10 +1301,25 @@ pub fn ensure_distribution(
13011301
// Allow subset satisfaction when:
13021302
// 1. Current partition count >= threshold
13031303
// 2. Not a partitioned join since must use exact hash matching for joins
1304+
// 3. Not a grouping set aggregate (requires exact hash including __grouping_id)
13041305
let current_partitions = child.plan.output_partitioning().partition_count();
1306+
1307+
// Check if the hash partitioning requirement includes __grouping_id column.
1308+
// Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash
1309+
// partitioning on all group columns including __grouping_id to ensure partial
1310+
// aggregates from different partitions are correctly combined.
1311+
let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs)
1312+
if exprs.iter().any(|expr| {
1313+
expr.as_any()
1314+
.downcast_ref::<Column>()
1315+
.is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID)
1316+
})
1317+
);
1318+
13051319
let allow_subset_satisfy_partitioning = current_partitions
13061320
>= subset_satisfaction_threshold
1307-
&& !is_partitioned_join;
1321+
&& !is_partitioned_join
1322+
&& !requires_grouping_id;
13081323

13091324
// When `repartition_file_scans` is set, attempt to increase
13101325
// parallelism at the source.
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
##########
19+
# Tests for ROLLUP/CUBE/GROUPING SETS with multiple partitions
20+
#
21+
# This tests the fix for https://github.com/apache/datafusion/issues/19849
22+
# where ROLLUP queries produced incorrect results with multiple partitions
23+
# because subset partitioning satisfaction was incorrectly applied.
24+
#
25+
# The bug manifests when:
26+
# 1. UNION ALL of subqueries each with hash-partitioned aggregates
27+
# 2. Outer ROLLUP groups by more columns than inner hash partitioning
28+
# 3. InterleaveExec preserves the inner hash partitioning
29+
# 4. Optimizer incorrectly uses subset satisfaction, skipping necessary repartition
30+
#
31+
# The fix ensures that when hash partitioning includes __grouping_id,
32+
# subset satisfaction is disabled and proper RepartitionExec is inserted.
33+
##########
34+
35+
##########
36+
# SETUP: Create partitioned parquet files to simulate distributed data
37+
##########
38+
39+
statement ok
40+
set datafusion.execution.target_partitions = 4;
41+
42+
statement ok
43+
set datafusion.optimizer.repartition_aggregations = true;
44+
45+
# Create partition 1
46+
statement ok
47+
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
48+
('store', 'nike', 100),
49+
('store', 'nike', 200),
50+
('store', 'adidas', 150)
51+
))
52+
TO 'test_files/scratch/grouping_set_repartition/part=1/data.parquet'
53+
STORED AS PARQUET;
54+
55+
# Create partition 2
56+
statement ok
57+
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
58+
('store', 'adidas', 250),
59+
('web', 'nike', 300),
60+
('web', 'nike', 400)
61+
))
62+
TO 'test_files/scratch/grouping_set_repartition/part=2/data.parquet'
63+
STORED AS PARQUET;
64+
65+
# Create partition 3
66+
statement ok
67+
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
68+
('web', 'adidas', 350),
69+
('web', 'adidas', 450),
70+
('catalog', 'nike', 500)
71+
))
72+
TO 'test_files/scratch/grouping_set_repartition/part=3/data.parquet'
73+
STORED AS PARQUET;
74+
75+
# Create partition 4
76+
statement ok
77+
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
78+
('catalog', 'nike', 600),
79+
('catalog', 'adidas', 550),
80+
('catalog', 'adidas', 650)
81+
))
82+
TO 'test_files/scratch/grouping_set_repartition/part=4/data.parquet'
83+
STORED AS PARQUET;
84+
85+
# Create external table pointing to the partitioned data
86+
statement ok
87+
CREATE EXTERNAL TABLE sales (channel VARCHAR, brand VARCHAR, amount INT)
88+
STORED AS PARQUET
89+
PARTITIONED BY (part INT)
90+
LOCATION 'test_files/scratch/grouping_set_repartition/';
91+
92+
##########
93+
# TEST 1: UNION ALL + ROLLUP pattern (similar to TPC-DS q14)
94+
# This query pattern triggers the subset satisfaction bug because:
95+
# - Each UNION ALL branch has hash partitioning on (brand)
96+
# - The outer ROLLUP requires hash partitioning on (channel, brand, __grouping_id)
97+
# - Without the fix, subset satisfaction incorrectly skips repartition
98+
#
99+
# Verify the physical plan includes RepartitionExec with __grouping_id
100+
##########
101+
102+
query TT
103+
EXPLAIN SELECT channel, brand, SUM(total) as grand_total
104+
FROM (
105+
SELECT 'store' as channel, brand, SUM(amount) as total
106+
FROM sales WHERE channel = 'store'
107+
GROUP BY brand
108+
UNION ALL
109+
SELECT 'web' as channel, brand, SUM(amount) as total
110+
FROM sales WHERE channel = 'web'
111+
GROUP BY brand
112+
UNION ALL
113+
SELECT 'catalog' as channel, brand, SUM(amount) as total
114+
FROM sales WHERE channel = 'catalog'
115+
GROUP BY brand
116+
) sub
117+
GROUP BY ROLLUP(channel, brand)
118+
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
119+
----
120+
logical_plan
121+
01)Sort: sub.channel ASC NULLS FIRST, sub.brand ASC NULLS FIRST
122+
02)--Projection: sub.channel, sub.brand, sum(sub.total) AS grand_total
123+
03)----Aggregate: groupBy=[[ROLLUP (sub.channel, sub.brand)]], aggr=[[sum(sub.total)]]
124+
04)------SubqueryAlias: sub
125+
05)--------Union
126+
06)----------Projection: Utf8("store") AS channel, sales.brand, sum(sales.amount) AS total
127+
07)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
128+
08)--------------Projection: sales.brand, sales.amount
129+
09)----------------Filter: sales.channel = Utf8View("store")
130+
10)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("store")]
131+
11)----------Projection: Utf8("web") AS channel, sales.brand, sum(sales.amount) AS total
132+
12)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
133+
13)--------------Projection: sales.brand, sales.amount
134+
14)----------------Filter: sales.channel = Utf8View("web")
135+
15)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("web")]
136+
16)----------Projection: Utf8("catalog") AS channel, sales.brand, sum(sales.amount) AS total
137+
17)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
138+
18)--------------Projection: sales.brand, sales.amount
139+
19)----------------Filter: sales.channel = Utf8View("catalog")
140+
20)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("catalog")]
141+
physical_plan
142+
01)SortPreservingMergeExec: [channel@0 ASC, brand@1 ASC]
143+
02)--SortExec: expr=[channel@0 ASC, brand@1 ASC], preserve_partitioning=[true]
144+
03)----ProjectionExec: expr=[channel@0 as channel, brand@1 as brand, sum(sub.total)@3 as grand_total]
145+
04)------AggregateExec: mode=FinalPartitioned, gby=[channel@0 as channel, brand@1 as brand, __grouping_id@2 as __grouping_id], aggr=[sum(sub.total)]
146+
05)--------RepartitionExec: partitioning=Hash([channel@0, brand@1, __grouping_id@2], 4), input_partitions=4
147+
06)----------AggregateExec: mode=Partial, gby=[(NULL as channel, NULL as brand), (channel@0 as channel, NULL as brand), (channel@0 as channel, brand@1 as brand)], aggr=[sum(sub.total)]
148+
07)------------InterleaveExec
149+
08)--------------ProjectionExec: expr=[store as channel, brand@0 as brand, sum(sales.amount)@1 as total]
150+
09)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
151+
10)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
152+
11)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
153+
12)----------------------FilterExec: channel@0 = store, projection=[brand@1, amount@2]
154+
13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = store, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= store AND store <= channel_max@1, required_guarantees=[channel in (store)]
155+
14)--------------ProjectionExec: expr=[web as channel, brand@0 as brand, sum(sales.amount)@1 as total]
156+
15)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
157+
16)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
158+
17)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
159+
18)----------------------FilterExec: channel@0 = web, projection=[brand@1, amount@2]
160+
19)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = web, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= web AND web <= channel_max@1, required_guarantees=[channel in (web)]
161+
20)--------------ProjectionExec: expr=[catalog as channel, brand@0 as brand, sum(sales.amount)@1 as total]
162+
21)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
163+
22)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
164+
23)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
165+
24)----------------------FilterExec: channel@0 = catalog, projection=[brand@1, amount@2]
166+
25)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = catalog, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= catalog AND catalog <= channel_max@1, required_guarantees=[channel in (catalog)]
167+
168+
query TTI rowsort
169+
SELECT channel, brand, SUM(total) as grand_total
170+
FROM (
171+
SELECT 'store' as channel, brand, SUM(amount) as total
172+
FROM sales WHERE channel = 'store'
173+
GROUP BY brand
174+
UNION ALL
175+
SELECT 'web' as channel, brand, SUM(amount) as total
176+
FROM sales WHERE channel = 'web'
177+
GROUP BY brand
178+
UNION ALL
179+
SELECT 'catalog' as channel, brand, SUM(amount) as total
180+
FROM sales WHERE channel = 'catalog'
181+
GROUP BY brand
182+
) sub
183+
GROUP BY ROLLUP(channel, brand)
184+
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
185+
----
186+
NULL NULL 4500
187+
catalog NULL 2300
188+
catalog adidas 1200
189+
catalog nike 1100
190+
store NULL 700
191+
store adidas 400
192+
store nike 300
193+
web NULL 1500
194+
web adidas 800
195+
web nike 700
196+
197+
##########
198+
# TEST 2: Simple ROLLUP (baseline test)
199+
##########
200+
201+
query TTI rowsort
202+
SELECT channel, brand, SUM(amount) as total
203+
FROM sales
204+
GROUP BY ROLLUP(channel, brand)
205+
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
206+
----
207+
NULL NULL 4500
208+
catalog NULL 2300
209+
catalog adidas 1200
210+
catalog nike 1100
211+
store NULL 700
212+
store adidas 400
213+
store nike 300
214+
web NULL 1500
215+
web adidas 800
216+
web nike 700
217+
218+
##########
219+
# TEST 3: Verify CUBE also works correctly
220+
##########
221+
222+
query TTI rowsort
223+
SELECT channel, brand, SUM(amount) as total
224+
FROM sales
225+
GROUP BY CUBE(channel, brand)
226+
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
227+
----
228+
NULL NULL 4500
229+
NULL adidas 2400
230+
NULL nike 2100
231+
catalog NULL 2300
232+
catalog adidas 1200
233+
catalog nike 1100
234+
store NULL 700
235+
store adidas 400
236+
store nike 300
237+
web NULL 1500
238+
web adidas 800
239+
web nike 700
240+
241+
##########
242+
# CLEANUP
243+
##########
244+
245+
statement ok
246+
DROP TABLE sales;

0 commit comments

Comments (0)