Skip to content

Commit 3bdd9e7

Browse files
committed
fix: preserve duplicate GROUPING SETS rows
1 parent 4010a55 commit 3bdd9e7

File tree

5 files changed

+199
-35
lines changed

5 files changed

+199
-35
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 68 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use crate::utils::{
4545
grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction,
4646
};
4747
use crate::{
48-
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable,
48+
BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, GroupingSet,
4949
LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource,
5050
WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed,
5151
};
@@ -3519,11 +3519,12 @@ impl Aggregate {
35193519
.into_iter()
35203520
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
35213521
.collect::<Vec<_>>();
3522+
let max_ordinal = max_grouping_set_duplicate_ordinal(&group_expr);
35223523
qualified_fields.push((
35233524
None,
35243525
Field::new(
35253526
Self::INTERNAL_GROUPING_ID,
3526-
Self::grouping_id_type(qualified_fields.len()),
3527+
Self::grouping_id_type(qualified_fields.len(), max_ordinal),
35273528
false,
35283529
)
35293530
.into(),
@@ -3609,15 +3610,25 @@ impl Aggregate {
36093610
}
36103611

36113612
/// Returns the data type of the grouping id.
3612-
/// The grouping ID value is a bitmask where each set bit
3613-
/// indicates that the corresponding grouping expression is
3614-
/// null
3615-
pub fn grouping_id_type(group_exprs: usize) -> DataType {
3616-
if group_exprs <= 8 {
3613+
///
3614+
/// The grouping ID packs two pieces of information into a single integer:
3615+
/// - The low `group_exprs` bits are the semantic bitmask (a set bit means the
3616+
/// corresponding grouping expression is NULL for this grouping set).
3617+
/// - The bits above position `group_exprs` encode a duplicate ordinal that
3618+
/// distinguishes multiple occurrences of the same grouping set pattern.
3619+
///
3620+
/// `max_ordinal` is the highest ordinal value that will appear (0 when there
3621+
/// are no duplicate grouping sets). The type is chosen to be the smallest
3622+
/// unsigned integer that can represent both parts.
3623+
pub fn grouping_id_type(group_exprs: usize, max_ordinal: usize) -> DataType {
3624+
let ordinal_bits =
3625+
usize::BITS as usize - max_ordinal.leading_zeros() as usize;
3626+
let total_bits = group_exprs + ordinal_bits;
3627+
if total_bits <= 8 {
36173628
DataType::UInt8
3618-
} else if group_exprs <= 16 {
3629+
} else if total_bits <= 16 {
36193630
DataType::UInt16
3620-
} else if group_exprs <= 32 {
3631+
} else if total_bits <= 32 {
36213632
DataType::UInt32
36223633
} else {
36233634
DataType::UInt64
@@ -3626,21 +3637,36 @@ impl Aggregate {
36263637

36273638
/// Internal column used when the aggregation is a grouping set.
36283639
///
3629-
/// This column contains a bitmask where each bit represents a grouping
3630-
/// expression. The least significant bit corresponds to the rightmost
3631-
/// grouping expression. A bit value of 0 indicates that the corresponding
3632-
/// column is included in the grouping set, while a value of 1 means it is excluded.
3640+
/// This column packs two values into a single unsigned integer:
3641+
///
3642+
/// - **Low bits (positions 0 .. n-1)**: a semantic bitmask where each bit
3643+
/// represents one of the `n` grouping expressions. The least significant
3644+
/// bit corresponds to the rightmost grouping expression. A `1` bit means
3645+
/// the corresponding column is replaced with `NULL` for this grouping set;
3646+
/// a `0` bit means it is included.
3647+
/// - **High bits (positions n and above)**: a *duplicate ordinal* that
3648+
/// distinguishes multiple occurrences of the same semantic grouping set
3649+
/// pattern within a single query. The ordinal is `0` for the first
3650+
/// occurrence, `1` for the second, and so on.
36333651
///
3634-
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
3635-
/// column will have the following values:
3652+
/// The integer type is chosen by [`Self::grouping_id_type`] to be the
3653+
/// smallest `UInt8 / UInt16 / UInt32 / UInt64` that can represent both
3654+
/// parts.
3655+
///
3656+
/// For example, for the grouping expressions CUBE(a, b) (no duplicates),
3657+
/// the grouping ID column will have the following values:
36363658
/// 0b00: Both `a` and `b` are included
36373659
/// 0b01: `b` is excluded
36383660
/// 0b10: `a` is excluded
36393661
/// 0b11: Both `a` and `b` are excluded
36403662
///
3641-
/// This internal column is necessary because excluded columns are replaced
3642-
/// with `NULL` values. To handle these cases correctly, we must distinguish
3643-
/// between an actual `NULL` value in a column and a column being excluded from the set.
3663+
/// When the same set appears twice and `n = 2`, the duplicate ordinal is
3664+
/// packed into bit 2:
3665+
/// first occurrence: `0b0_01` (ordinal = 0, mask = 0b01)
3666+
/// second occurrence: `0b1_01` (ordinal = 1, mask = 0b01)
3667+
///
3668+
/// The GROUPING function always masks the value with `(1 << n) - 1` before
3669+
/// interpreting it so the ordinal bits are invisible to user-facing SQL.
36443670
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
36453671
}
36463672

@@ -3661,6 +3687,30 @@ impl PartialOrd for Aggregate {
36613687
}
36623688
}
36633689

3690+
/// Returns the highest duplicate ordinal across all grouping sets in `group_expr`.
3691+
///
3692+
/// The ordinal counts how many times a given grouping set pattern has already
3693+
/// appeared before the current occurrence. For example, if the same set
3694+
/// appears three times the ordinals are 0, 1, 2 and this function returns 2.
3695+
/// Returns 0 when no grouping set is duplicated.
3696+
fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize {
3697+
if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() {
3698+
let mut counts: HashMap<&Vec<Expr>, usize> =
3699+
HashMap::new();
3700+
for set in sets {
3701+
*counts.entry(set).or_insert(0) += 1;
3702+
}
3703+
counts
3704+
.values()
3705+
.copied()
3706+
.max()
3707+
.unwrap_or(1)
3708+
.saturating_sub(1)
3709+
} else {
3710+
0
3711+
}
3712+
}
3713+
36643714
/// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`.
36653715
fn contains_grouping_set(group_expr: &[Expr]) -> bool {
36663716
group_expr

datafusion/optimizer/src/analyzer/resolve_grouping_function.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,25 @@ fn grouping_function_on_id(
204204
Expr::Literal(ScalarValue::from(value as u64), None)
205205
}
206206
};
207-
208207
let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
209-
// The grouping call is exactly our internal grouping id
208+
// The grouping call is exactly our internal grouping id — mask the ordinal
209+
// bits (above position `n`) so only the semantic bitmask is visible.
210+
let n = group_by_expr_count;
211+
// `(1 << n) - 1` keeps only the low n bits (0 when n == 0). n >= 64 is handled separately because `1u64 << 64` would overflow.
212+
let semantic_mask: u64 = if n >= 64 {
213+
u64::MAX
214+
} else {
215+
(1u64 << n).wrapping_sub(1)
216+
};
217+
let masked_id = bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize));
210218
if args.len() == group_by_expr_count
211219
&& args
212220
.iter()
213221
.rev()
214222
.enumerate()
215223
.all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
216224
{
217-
return Ok(cast(grouping_id_column, DataType::Int32));
225+
return Ok(cast(masked_id, DataType::Int32));
218226
}
219227

220228
args.iter()

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 77 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use crate::{
3838
use datafusion_common::config::ConfigOptions;
3939
use datafusion_physical_expr::utils::collect_columns;
4040
use parking_lot::Mutex;
41-
use std::collections::HashSet;
41+
use std::collections::{HashMap, HashSet};
4242

4343
use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array};
4444
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -397,6 +397,14 @@ impl PhysicalGroupBy {
397397
self.expr.len() + usize::from(self.has_grouping_set)
398398
}
399399

400+
/// Returns the Arrow data type of the `__grouping_id` column.
401+
///
402+
/// The type is chosen to be wide enough to hold both the semantic bitmask
403+
/// (in the low `n` bits) and the duplicate ordinal (in the high bits).
404+
fn grouping_id_data_type(&self) -> DataType {
405+
Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups))
406+
}
407+
400408
pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> {
401409
Ok(Arc::new(Schema::new(self.group_fields(schema)?)))
402410
}
@@ -421,7 +429,7 @@ impl PhysicalGroupBy {
421429
fields.push(
422430
Field::new(
423431
Aggregate::INTERNAL_GROUPING_ID,
424-
Aggregate::grouping_id_type(self.expr.len()),
432+
self.grouping_id_data_type(),
425433
false,
426434
)
427435
.into(),
@@ -1937,27 +1945,77 @@ fn evaluate_optional(
19371945
.collect()
19381946
}
19391947

1940-
fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
1941-
if group.len() > 64 {
1948+
/// Builds the internal `__grouping_id` array for a single grouping set.
1949+
///
1950+
/// The returned array packs two values into a single integer:
1951+
///
1952+
/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask. A `1` bit
1953+
/// at position `i` means that the `i`-th grouping column (counting from the
1954+
/// least significant bit, i.e. the *last* column in the `group` slice) is
1955+
/// `NULL` for this grouping set.
1956+
/// - High bits (positions n and above): the duplicate `ordinal`, which
1957+
/// distinguishes multiple occurrences of the same grouping-set pattern. The
1958+
/// ordinal is `0` for the first occurrence, `1` for the second, and so on.
1959+
///
1960+
/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 /
1961+
/// UInt64` that can represent both parts. It matches the type returned by
1962+
/// [`Aggregate::grouping_id_type`].
1963+
fn group_id_array(
1964+
group: &[bool],
1965+
ordinal: usize,
1966+
max_ordinal: usize,
1967+
batch: &RecordBatch,
1968+
) -> Result<ArrayRef> {
1969+
let n = group.len();
1970+
if n > 64 {
19421971
return not_impl_err!(
19431972
"Grouping sets with more than 64 columns are not supported"
19441973
);
19451974
}
1946-
let group_id = group.iter().fold(0u64, |acc, &is_null| {
1975+
let ordinal_bits =
1976+
usize::BITS as usize - max_ordinal.leading_zeros() as usize;
1977+
let total_bits = n + ordinal_bits;
1978+
if total_bits > 64 {
1979+
return not_impl_err!(
1980+
"Grouping sets with {n} columns and a maximum duplicate ordinal of \
1981+
{max_ordinal} require {total_bits} bits, which exceeds 64"
1982+
);
1983+
}
1984+
let semantic_id = group.iter().fold(0u64, |acc, &is_null| {
19471985
(acc << 1) | if is_null { 1 } else { 0 }
19481986
});
1987+
let full_id = semantic_id | ((ordinal as u64) << n);
19491988
let num_rows = batch.num_rows();
1950-
if group.len() <= 8 {
1951-
Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
1952-
} else if group.len() <= 16 {
1953-
Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
1954-
} else if group.len() <= 32 {
1955-
Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
1989+
if total_bits <= 8 {
1990+
Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows])))
1991+
} else if total_bits <= 16 {
1992+
Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows])))
1993+
} else if total_bits <= 32 {
1994+
Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows])))
19561995
} else {
1957-
Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
1996+
Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows])))
19581997
}
19591998
}
19601999

2000+
/// Computes the largest duplicate ordinal that any grouping set in `groups` receives.
///
/// Identical `Vec<bool>` patterns are numbered 0, 1, 2, ... in order of
/// appearance, so the result is one less than the multiplicity of the most
/// repeated pattern. An empty slice, or one without repeated patterns,
/// yields 0.
fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
    // Tally how many times each distinct pattern occurs.
    let occurrences = groups.iter().fold(
        HashMap::<&Vec<bool>, usize>::new(),
        |mut seen, pattern| {
            *seen.entry(pattern).or_insert(0) += 1;
            seen
        },
    );
    // Every recorded count is at least 1, so subtracting 1 cannot underflow.
    match occurrences.into_values().max() {
        Some(most_repeated) => most_repeated - 1,
        None => 0,
    }
}
2018+
19612019
/// Evaluate a group by expression against a `RecordBatch`
19622020
///
19632021
/// Arguments:
@@ -1972,6 +2030,8 @@ pub fn evaluate_group_by(
19722030
group_by: &PhysicalGroupBy,
19732031
batch: &RecordBatch,
19742032
) -> Result<Vec<Vec<ArrayRef>>> {
2033+
let max_ordinal = max_duplicate_ordinal(&group_by.groups);
2034+
let mut ordinal_per_pattern: HashMap<&Vec<bool>, usize> = HashMap::new();
19752035
let exprs = evaluate_expressions_to_arrays(
19762036
group_by.expr.iter().map(|(expr, _)| expr),
19772037
batch,
@@ -1985,6 +2045,10 @@ pub fn evaluate_group_by(
19852045
.groups
19862046
.iter()
19872047
.map(|group| {
2048+
let ordinal = ordinal_per_pattern.entry(group).or_insert(0);
2049+
let current_ordinal = *ordinal;
2050+
*ordinal += 1;
2051+
19882052
let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
19892053
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
19902054
if *is_null {
@@ -1994,7 +2058,7 @@ pub fn evaluate_group_by(
19942058
}
19952059
}));
19962060
if !group_by.is_single() {
1997-
group_values.push(group_id_array(group, batch)?);
2061+
group_values.push(group_id_array(group, current_ordinal, max_ordinal, batch)?);
19982062
}
19992063
Ok(group_values)
20002064
})

datafusion/sql/src/unparser/utils.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,14 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a E
247247
)
248248
}
249249
Ordering::Greater => {
250-
Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1))
250+
if index < grouping_expr.len() + 1 {
251+
return internal_err!(
252+
"Tried to unproject column referring to internal grouping column"
253+
);
254+
}
255+
Ok(agg
256+
.aggr_expr
257+
.get(index - grouping_expr.len() - 1))
251258
}
252259
}
253260
} else {

datafusion/sqllogictest/test_files/group_by.slt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5203,6 +5203,41 @@ NULL NULL 1
52035203
statement ok
52045204
drop table t;
52055205

5206+
# regression: duplicate grouping sets must not be collapsed into one
5207+
statement ok
5208+
create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) as values
5209+
(10, 'CLERK', 1300, null),
5210+
(20, 'MANAGER', 3000, null);
5211+
5212+
query ITIIIII
5213+
select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal)
5214+
from duplicate_grouping_sets
5215+
group by grouping sets ((deptno, job), (deptno, sal), (deptno, job))
5216+
order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal);
5217+
----
5218+
10 CLERK NULL NULL 0 0 1
5219+
10 CLERK NULL NULL 0 0 1
5220+
10 NULL 1300 NULL 0 1 0
5221+
20 MANAGER NULL NULL 0 0 1
5222+
20 MANAGER NULL NULL 0 0 1
5223+
20 NULL 3000 NULL 0 1 0
5224+
5225+
query ITII
5226+
select deptno, job, sal, grouping(deptno, job, sal)
5227+
from duplicate_grouping_sets
5228+
group by grouping sets ((deptno, job), (deptno, sal), (deptno, job))
5229+
order by deptno, job, sal, grouping(deptno, job, sal);
5230+
----
5231+
10 CLERK NULL 1
5232+
10 CLERK NULL 1
5233+
10 NULL 1300 2
5234+
20 MANAGER NULL 1
5235+
20 MANAGER NULL 1
5236+
20 NULL 3000 2
5237+
5238+
statement ok
5239+
drop table duplicate_grouping_sets;
5240+
52065241
# test multi group by for binary type without nulls
52075242
statement ok
52085243
create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb);

0 commit comments

Comments
 (0)