Skip to content

Commit eb5dae5

Browse files
committed
fix: preserve duplicate GROUPING SETS rows
1 parent 4010a55 commit eb5dae5

File tree

13 files changed

+427
-110
lines changed

13 files changed

+427
-110
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,14 +654,14 @@ impl DataFrame {
654654
.aggregate(group_expr, aggr_expr)?
655655
.build()?;
656656
let plan = if is_grouping_set {
657-
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
657+
let grouping_id_pos = plan.schema().fields().len() - 2 - aggr_expr_len;
658658
// For grouping sets we do a project to not expose the internal grouping id
659659
let exprs = plan
660660
.schema()
661661
.columns()
662662
.into_iter()
663663
.enumerate()
664-
.filter(|(idx, _)| *idx != grouping_id_pos)
664+
.filter(|(idx, _)| *idx < grouping_id_pos || *idx >= grouping_id_pos + 2)
665665
.map(|(_, column)| Expr::Column(column))
666666
.collect::<Vec<_>>();
667667
LogicalPlanBuilder::from(plan).project(exprs)?.build()?

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,83 @@ async fn count_aggregated_cube() -> Result<()> {
175175
Ok(())
176176
}
177177

178+
#[tokio::test]
179+
async fn duplicate_grouping_sets_are_preserved() -> Result<()> {
180+
let ctx = SessionContext::new();
181+
let schema = Arc::new(Schema::new(vec![
182+
Field::new("deptno", DataType::Int32, false),
183+
Field::new("job", DataType::Utf8, true),
184+
Field::new("sal", DataType::Int32, true),
185+
Field::new("comm", DataType::Int32, true),
186+
]));
187+
let batch = RecordBatch::try_new(
188+
Arc::clone(&schema),
189+
vec![
190+
Arc::new(Int32Array::from(vec![10, 20])),
191+
Arc::new(StringArray::from(vec![Some("CLERK"), Some("MANAGER")])),
192+
Arc::new(Int32Array::from(vec![1300, 3000])),
193+
Arc::new(Int32Array::from(vec![None, None])),
194+
],
195+
)?;
196+
let provider = MemTable::try_new(Arc::clone(&schema), vec![vec![batch]])?;
197+
ctx.register_table("dup_grouping_sets", Arc::new(provider))?;
198+
199+
let results = plan_and_collect(
200+
&ctx,
201+
"
202+
SELECT deptno, job, sal, sum(comm) AS sum_comm,
203+
grouping(deptno) AS deptno_flag,
204+
grouping(job) AS job_flag,
205+
grouping(sal) AS sal_flag
206+
FROM dup_grouping_sets
207+
GROUP BY GROUPING SETS ((deptno, job), (deptno, sal), (deptno, job))
208+
ORDER BY deptno, job, sal, deptno_flag, job_flag, sal_flag
209+
",
210+
)
211+
.await?;
212+
213+
assert_eq!(results.len(), 1);
214+
assert_snapshot!(batches_to_string(&results), @r"
215+
+--------+---------+------+----------+-------------+----------+----------+
216+
| deptno | job | sal | sum_comm | deptno_flag | job_flag | sal_flag |
217+
+--------+---------+------+----------+-------------+----------+----------+
218+
| 10 | CLERK | | | 0 | 0 | 1 |
219+
| 10 | CLERK | | | 0 | 0 | 1 |
220+
| 10 | | 1300 | | 0 | 1 | 0 |
221+
| 20 | MANAGER | | | 0 | 0 | 1 |
222+
| 20 | MANAGER | | | 0 | 0 | 1 |
223+
| 20 | | 3000 | | 0 | 1 | 0 |
224+
+--------+---------+------+----------+-------------+----------+----------+
225+
");
226+
227+
let results = plan_and_collect(
228+
&ctx,
229+
"
230+
SELECT deptno, job, sal,
231+
grouping(deptno, job, sal) AS grouping_id
232+
FROM dup_grouping_sets
233+
GROUP BY GROUPING SETS ((deptno, job), (deptno, sal), (deptno, job))
234+
ORDER BY deptno, job, sal, grouping_id
235+
",
236+
)
237+
.await?;
238+
239+
assert_eq!(results.len(), 1);
240+
assert_snapshot!(batches_to_string(&results), @r"
241+
+--------+---------+------+-------------+
242+
| deptno | job | sal | grouping_id |
243+
+--------+---------+------+-------------+
244+
| 10 | CLERK | | 1 |
245+
| 10 | CLERK | | 1 |
246+
| 10 | | 1300 | 2 |
247+
| 20 | MANAGER | | 1 |
248+
| 20 | MANAGER | | 1 |
249+
| 20 | | 3000 | 2 |
250+
+--------+---------+------+-------------+
251+
");
252+
Ok(())
253+
}
254+
178255
async fn run_count_distinct_integers_aggregated_scenario(
179256
partitions: Vec<Vec<(&str, u64)>>,
180257
) -> Result<Vec<RecordBatch>> {

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3528,6 +3528,11 @@ impl Aggregate {
35283528
)
35293529
.into(),
35303530
));
3531+
qualified_fields.push((
3532+
None,
3533+
Field::new(Self::INTERNAL_GROUPING_ORDINAL, DataType::UInt32, false)
3534+
.into(),
3535+
));
35313536
}
35323537

35333538
qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
@@ -3592,9 +3597,13 @@ impl Aggregate {
35923597
static INTERNAL_ID_EXPR: LazyLock<Expr> = LazyLock::new(|| {
35933598
Expr::Column(Column::from_name(Aggregate::INTERNAL_GROUPING_ID))
35943599
});
3600+
static INTERNAL_ORDINAL_EXPR: LazyLock<Expr> = LazyLock::new(|| {
3601+
Expr::Column(Column::from_name(Aggregate::INTERNAL_GROUPING_ORDINAL))
3602+
});
35953603
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
35963604
if self.is_grouping_set() {
35973605
exprs.push(&INTERNAL_ID_EXPR);
3606+
exprs.push(&INTERNAL_ORDINAL_EXPR);
35983607
}
35993608
exprs.extend(self.aggr_expr.iter());
36003609
debug_assert!(exprs.len() == self.schema.fields().len());
@@ -3642,6 +3651,13 @@ impl Aggregate {
36423651
/// with `NULL` values. To handle these cases correctly, we must distinguish
36433652
/// between an actual `NULL` value in a column and a column being excluded from the set.
36443653
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
3654+
3655+
/// Internal column used when duplicate grouping sets are present.
3656+
///
3657+
/// This column stores the ordinal of the grouping set among all grouping sets
3658+
/// with the same semantic grouping mask, allowing the physical aggregation key
3659+
/// to distinguish duplicate grouping sets without overloading `__grouping_id`.
3660+
pub const INTERNAL_GROUPING_ORDINAL: &'static str = "__grouping_ordinal";
36453661
}
36463662

36473663
// Manual implementation needed because of `schema` field. Comparison excludes this field.

datafusion/expr/src/utils.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
5959
"Invalid group by expressions, GroupingSet must be the only expression"
6060
);
6161
}
62-
// Groupings sets have an additional integral column for the grouping id
63-
Ok(grouping_set.distinct_expr().len() + 1)
62+
// Grouping sets have additional internal columns for grouping semantics and
63+
// duplicate-grouping-set ordinals.
64+
Ok(grouping_set.distinct_expr().len() + 2)
6465
} else {
6566
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
6667
}

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ fn replace_grouping_exprs(
8484
let columns = schema.columns();
8585
let mut new_agg_expr = Vec::new();
8686
let mut projection_exprs = Vec::new();
87-
let grouping_id_len = if is_grouping_set { 1 } else { 0 };
87+
let grouping_id_len = if is_grouping_set { 2 } else { 0 };
8888
let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
8989
projection_exprs.extend(
9090
columns
@@ -204,7 +204,6 @@ fn grouping_function_on_id(
204204
Expr::Literal(ScalarValue::from(value as u64), None)
205205
}
206206
};
207-
208207
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
209208
// The grouping call is exactly our internal grouping id
210209
if args.len() == group_by_expr_count

datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ fn aggregate_output_exprs(group_expr: &[Expr]) -> Result<Vec<Expr>> {
238238
output_exprs.push(Expr::Column(Column::from_name(
239239
Aggregate::INTERNAL_GROUPING_ID,
240240
)));
241+
output_exprs.push(Expr::Column(Column::from_name(
242+
Aggregate::INTERNAL_GROUPING_ORDINAL,
243+
)));
241244
}
242245

243246
Ok(output_exprs)

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ mod tests {
380380
assert_optimized_plan_equal!(
381381
plan,
382382
@r"
383-
Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
383+
Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, __grouping_ordinal:UInt32, count(DISTINCT test.c):Int64]
384384
TableScan: test [a:UInt32, b:UInt32, c:UInt32]
385385
"
386386
)
@@ -401,7 +401,7 @@ mod tests {
401401
assert_optimized_plan_equal!(
402402
plan,
403403
@r"
404-
Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
404+
Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, __grouping_ordinal:UInt32, count(DISTINCT test.c):Int64]
405405
TableScan: test [a:UInt32, b:UInt32, c:UInt32]
406406
"
407407
)
@@ -423,7 +423,7 @@ mod tests {
423423
assert_optimized_plan_equal!(
424424
plan,
425425
@r"
426-
Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]
426+
Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, __grouping_ordinal:UInt32, count(DISTINCT test.c):Int64]
427427
TableScan: test [a:UInt32, b:UInt32, c:UInt32]
428428
"
429429
)

datafusion/physical-optimizer/src/enforce_distribution.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,25 +1280,31 @@ pub fn ensure_distribution(
12801280
// Allow subset satisfaction when:
12811281
// 1. Current partition count >= threshold
12821282
// 2. Not a partitioned join since must use exact hash matching for joins
1283-
// 3. Not a grouping set aggregate (requires exact hash including __grouping_id)
1283+
// 3. Not a grouping set aggregate (requires exact hash including internal grouping columns)
12841284
let current_partitions = child.plan.output_partitioning().partition_count();
12851285

1286-
// Check if the hash partitioning requirement includes __grouping_id column.
1286+
// Check if the hash partitioning requirement includes internal grouping columns.
12871287
// Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash
1288-
// partitioning on all group columns including __grouping_id to ensure partial
1289-
// aggregates from different partitions are correctly combined.
1290-
let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs)
1288+
// partitioning on all group columns including the internal grouping columns
1289+
// to ensure partial aggregates from different partitions are correctly combined.
1290+
let requires_grouping_key = matches!(&requirement, Distribution::HashPartitioned(exprs)
12911291
if exprs.iter().any(|expr| {
12921292
expr.as_any()
12931293
.downcast_ref::<Column>()
1294-
.is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID)
1294+
.is_some_and(|col| {
1295+
matches!(
1296+
col.name(),
1297+
Aggregate::INTERNAL_GROUPING_ID
1298+
| Aggregate::INTERNAL_GROUPING_ORDINAL
1299+
)
1300+
})
12951301
})
12961302
);
12971303

12981304
let allow_subset_satisfy_partitioning = current_partitions
12991305
>= subset_satisfaction_threshold
13001306
&& !is_partitioned_join
1301-
&& !requires_grouping_id;
1307+
&& !requires_grouping_key;
13021308

13031309
// When `repartition_file_scans` is set, attempt to increase
13041310
// parallelism at the source.

0 commit comments

Comments (0)