Skip to content

Commit 045f843

Browse files
committed
[SQL] Return one row for GROUP BY GROUPING SETS (()) over empty input
`GROUP BY GROUPING SETS (())` is a grand total, semantically identical to an aggregation with no `GROUP BY` clause. It was lowered to a grouped `Aggregate` over an `Expand` (grouping by `spark_grouping_id`), so over empty input it returned zero rows instead of one. The same defect affected the equivalent empty `GROUP BY CUBE()` and `GROUP BY ROLLUP()`, which also lower to a single empty grouping set. This lowers the single-empty-grouping-set case to a global `Aggregate` (no grouping expressions, no `Expand`) in `GroupingAnalyticsTransformer`, so it returns one (grand total) row over empty input, matching the `GROUP BY`-less form and the SQL standard. `grouping_id()` folds to the constant `0`, and `grouping()`/`grouping_id()` in `HAVING`/`ORDER BY` resolve against that constant. The fix lands in both the legacy fixed-point analyzer and the single-pass resolver, which share `GroupingAnalyticsTransformer`. The behavior is gated by an internal SQL config, `spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled` (default true). When set to false, lowering reverts to the legacy `Expand`-based form (zero rows over empty input). The flag gates all three decision points (the transformer lowering and the `grouping_id()` resolution in each analyzer) so the off state reproduces pre-fix behavior identically in both analyzers. Tested via golden cases in `grouping_set.sql` (empty-input grand total, `grouping_id()` in SELECT/HAVING/ORDER BY, `grouping()` rejection, non-empty input), flag-off coverage in `grouping_set_grand_total_disabled.sql`, the regenerated `group-analytics` golden file, and `ResolveGroupingAnalyticsSuite`. Co-authored-by: Isaac
1 parent 2539e18 commit 045f843

13 files changed

Lines changed: 672 additions & 69 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -840,11 +840,12 @@ class Analyzer(
840840
// We should make sure all expressions in condition have been resolved.
841841
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
842842
val groupingExprs = findGroupingExprs(child)
843-
// The unresolved grouping id will be resolved by ResolveReferences
843+
// For the grand total this is a resolved Literal(0); otherwise the unresolved
844+
// spark_grouping_id attribute is resolved by ResolveReferences.
844845
val newCond = GroupingAnalyticsTransformer.replaceGroupingFunction(
845846
expression = cond,
846847
groupByExpressions = groupingExprs,
847-
gid = VirtualColumn.groupingIdAttribute,
848+
gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs),
848849
newAlias = (child, name, qualifier) =>
849850
Alias(child, name.get)(qualifier = qualifier)
850851
)
@@ -854,8 +855,9 @@ class Analyzer(
854855
case s @ Sort(order, _, child, _)
855856
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
856857
val groupingExprs = findGroupingExprs(child)
857-
val gid = VirtualColumn.groupingIdAttribute
858-
// The unresolved grouping id will be resolved by ResolveReferences
858+
val gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs)
859+
// For the grand total this is a resolved Literal(0); otherwise the unresolved
860+
// spark_grouping_id attribute is resolved by ResolveReferences.
859861
val newOrder = order.map { expression =>
860862
GroupingAnalyticsTransformer.replaceGroupingFunction(
861863
expression = expression,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala

Lines changed: 130 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2323
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
2424
import org.apache.spark.sql.catalyst.util.toPrettySQL
2525
import org.apache.spark.sql.errors.QueryCompilationErrors
26+
import org.apache.spark.sql.internal.SQLConf
2627
import org.apache.spark.sql.types.ByteType
2728

2829
/**
@@ -64,6 +65,12 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
6465
* +- LocalRelation [col1#0]
6566
* }}}
6667
*
68+
* The grand-total-only case `GROUP BY GROUPING SETS (())` (and the equivalent empty `CUBE()` /
69+
* `ROLLUP()`) is the exception: instead of an [[Expand]] it is lowered to a global [[Aggregate]]
70+
* with no grouping expressions (see [[shouldLowerToGrandTotalAggregate]] and
71+
* [[constructGrandTotalAggregate]]), so it returns one row over empty input, like an aggregation
72+
* with no GROUP BY clause.
73+
*
6774
* @param newAlias Function to create new aliases, takes expression, optional name, and optional
6875
* qualifier
6976
* @param childOutput The output attributes of the child plan
@@ -81,34 +88,117 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
8188
child: LogicalPlan,
8289
aggregationExpressions: Seq[NamedExpression]): Aggregate = {
8390

84-
val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions)
91+
if (shouldLowerToGrandTotalAggregate(groupByExpressions, selectedGroupByExpressions)) {
92+
constructGrandTotalAggregate(newAlias, aggregationExpressions, child)
93+
} else {
94+
val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions)
8595

86-
val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
87-
val expand = constructExpand(
88-
selectedGroupByExpressions = selectedGroupByExpressions,
89-
child = child,
90-
groupByAliases = groupByAliases,
91-
gid = gid,
92-
childOutput = childOutput
93-
)
94-
val groupingAttributes = expand.output.drop(childOutput.length)
96+
val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
97+
val expand = constructExpand(
98+
selectedGroupByExpressions = selectedGroupByExpressions,
99+
child = child,
100+
groupByAliases = groupByAliases,
101+
gid = gid,
102+
childOutput = childOutput
103+
)
104+
val groupingAttributes = expand.output.drop(childOutput.length)
95105

96-
val aggregations = constructAggregateExpressions(
97-
newAlias = newAlias,
98-
groupByExpressions = groupByExpressions,
99-
aggregations = aggregationExpressions,
100-
groupByAliases = groupByAliases,
101-
groupingAttributes = groupingAttributes,
102-
gid = gid
103-
)
106+
val aggregations = constructAggregateExpressions(
107+
newAlias = newAlias,
108+
groupByExpressions = groupByExpressions,
109+
aggregations = aggregationExpressions,
110+
groupByAliases = groupByAliases,
111+
groupingAttributes = groupingAttributes,
112+
gid = gid
113+
)
114+
115+
Aggregate(
116+
groupingExpressions = groupingAttributes,
117+
aggregateExpressions = aggregations,
118+
child = expand
119+
)
120+
}
121+
}
104122

105-
val aggregate = Aggregate(
106-
groupingExpressions = groupingAttributes,
123+
/**
124+
* Whether a grouping-set spec is the grand-total-only case `GROUP BY GROUPING SETS (())` (and
125+
* the equivalent empty `CUBE()`/`ROLLUP()`) that [[apply]] lowers to a global [[Aggregate]]:
126+
* no leading group-by expressions and a single empty grouping set, with
127+
* [[SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE]] enabled. This is the
128+
* pre-lowering decision; [[isLoweredToGrandTotalAggregate]] detects the same case post-lowering.
129+
*/
130+
def shouldLowerToGrandTotalAggregate(
131+
groupByExpressions: Seq[Expression],
132+
selectedGroupByExpressions: Seq[Seq[Expression]]): Boolean = {
133+
conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE) &&
134+
groupByExpressions.isEmpty &&
135+
selectedGroupByExpressions.length == 1 &&
136+
selectedGroupByExpressions.head.isEmpty
137+
}
138+
139+
/**
140+
* Whether a lowered [[Aggregate]] is the grand total produced by
141+
* [[shouldLowerToGrandTotalAggregate]]: its resolved grouping expressions (without the
142+
* `spark_grouping_id` key, as returned by [[collectGroupingExpressions]]) are empty, with the
143+
* flag enabled. The flag is part of the check because with it off the same grand total is
144+
* lowered via [[Expand]], whose `spark_grouping_id` key must still be referenced.
145+
*
146+
* This keys purely on the collected grouping expressions being empty, so it also matches a
147+
* value-equivalent all-empty multi-set [[Expand]] aggregate (e.g. `GROUP BY GROUPING SETS ((),
148+
* ())`), whose only grouping key is `spark_grouping_id` and whose every row's grouping id is
149+
* likewise 0 (there are no group-by columns). Folding grouping_id() to the constant 0 is correct
150+
* for both.
151+
*/
152+
def isLoweredToGrandTotalAggregate(groupingExpressions: Seq[Expression]): Boolean = {
153+
conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE) &&
154+
groupingExpressions.isEmpty
155+
}
156+
157+
/**
158+
* The grouping id to substitute for grouping()/grouping_id() over a lowered grouping-analytics
159+
* [[Aggregate]]: the constant 0 for a grand total ([[isLoweredToGrandTotalAggregate]]), otherwise
160+
* the unresolved `spark_grouping_id` attribute produced by [[Expand]].
161+
*/
162+
def groupingIdExpression(groupingExpressions: Seq[Expression]): Expression = {
163+
if (isLoweredToGrandTotalAggregate(groupingExpressions)) {
164+
Literal.default(GroupingID.dataType)
165+
} else {
166+
VirtualColumn.groupingIdAttribute
167+
}
168+
}
169+
170+
/**
171+
* Build a global [[Aggregate]] (with no grouping expressions) for the grand-total-only case
172+
* `GROUP BY GROUPING SETS (())`.
173+
*
174+
* A single empty grouping set is a grand total, semantically identical to an aggregation with no
175+
* GROUP BY clause. Lowering it via [[Expand]] would group by `spark_grouping_id` and so emit zero
176+
* rows over empty input instead of one; a global [[Aggregate]] returns one (grand total) row,
177+
* matching the GROUP BY-less form and the SQL standard.
178+
*
179+
* Any [[GroupingID]] in the aggregation expressions evaluates to the constant `0` (the only
180+
* active grouping set is the empty one), and [[Grouping]] over a non-grouping column is rejected,
181+
* both reusing [[replaceGroupingFunction]] with a literal grouping id in place of the
182+
* Expand-produced `spark_grouping_id` attribute.
183+
*/
184+
private def constructGrandTotalAggregate(
185+
newAlias: (Expression, Option[String], Seq[String]) => Alias,
186+
aggregationExpressions: Seq[NamedExpression],
187+
child: LogicalPlan): Aggregate = {
188+
val groupingId = Literal.default(GroupingID.dataType)
189+
val aggregations = aggregationExpressions.map { expression =>
190+
replaceGroupingFunction(
191+
expression = expression,
192+
groupByExpressions = Seq.empty,
193+
gid = groupingId,
194+
newAlias = newAlias
195+
).asInstanceOf[NamedExpression]
196+
}
197+
Aggregate(
198+
groupingExpressions = Seq.empty,
107199
aggregateExpressions = aggregations,
108-
child = expand
200+
child = child
109201
)
110-
111-
aggregate
112202
}
113203

114204
/**
@@ -153,19 +243,27 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
153243
/**
154244
* Collect the last grouping expression since the provided [[Aggregate]] should have grouping id
155245
* as the last grouping key.
246+
*
247+
* A grand-total `GROUP BY GROUPING SETS (())` is lowered to a global [[Aggregate]] with no
248+
* grouping expressions (see [[apply]]), so there is no `spark_grouping_id` key and no grouping
249+
* columns; return an empty sequence for it.
156250
*/
157251
def collectGroupingExpressions(aggregate: Aggregate): Seq[Expression] = {
158-
val gid = aggregate.groupingExpressions.last
159-
gid match {
160-
case attributeReference: AttributeReference =>
161-
if (attributeReference.name != VirtualColumn.groupingIdName) {
252+
if (aggregate.groupingExpressions.isEmpty) {
253+
Seq.empty
254+
} else {
255+
val gid = aggregate.groupingExpressions.last
256+
gid match {
257+
case attributeReference: AttributeReference =>
258+
if (attributeReference.name != VirtualColumn.groupingIdName) {
259+
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
260+
}
261+
case _ =>
162262
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
163-
}
164-
case _ =>
165-
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
166-
}
263+
}
167264

168-
aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1)
265+
aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1)
266+
}
169267
}
170268

171269
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@ package org.apache.spark.sql.catalyst.analysis.resolver
1919

2020
import org.apache.spark.sql.catalyst.analysis.GroupingAnalyticsTransformer
2121
import org.apache.spark.sql.catalyst.expressions.{
22-
AttributeReference,
2322
BaseGroupingSets,
2423
Expression,
2524
GroupingAnalyticsExtractor,
26-
SortOrder,
27-
VirtualColumn
25+
SortOrder
2826
}
2927
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Sort}
3028
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -111,7 +109,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi
111109
* - If there is, collect grouping expressions from it using
112110
* [[GroupingAnalyticsTransformer.collectGroupingExpressions]].
113111
* - If there isn't, throw `groupingMustWithGroupingSetsOrCubeOrRollupError` exception.
114-
* 2. Create a grouping ID attribute and resolve it using [[ExpressionResolver]].
112+
* 2. Compute the grouping id via [[GroupingAnalyticsTransformer.groupingIdExpression]] (the
113+
* constant 0 for a grand total, otherwise the `spark_grouping_id` attribute) and resolve it
114+
* using [[ExpressionResolver]].
115115
* 3. Replace [[SortOrder]] expressions using
116116
* [[GroupingAnalyticsTransformer.replaceGroupingFunction]] (see its scala doc for more
117117
* details).
@@ -128,10 +128,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi
128128
scopes.current.baseAggregate.get
129129
)
130130

131-
val groupingId = VirtualColumn.groupingIdAttribute
132-
val resolvedGroupingId = expressionResolver
133-
.resolveExpressionTreeInOperator(groupingId, sort)
134-
.asInstanceOf[AttributeReference]
131+
val groupingId = GroupingAnalyticsTransformer.groupingIdExpression(groupingExpressions)
132+
val resolvedGroupingId =
133+
expressionResolver.resolveExpressionTreeInOperator(groupingId, sort)
135134

136135
val orderExpressionsWithGroupingAnalytics = orderExpressions.map { orderExpression =>
137136
GroupingAnalyticsTransformer

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,19 @@ object SQLConf {
292292
.booleanConf
293293
.createWithDefault(true)
294294

295+
val LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE =
296+
buildConf("spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled")
297+
.internal()
298+
.version("4.3.0")
299+
.doc(
300+
"When true, a grand-total GROUP BY GROUPING SETS (()) (and the equivalent empty " +
301+
"CUBE() / ROLLUP()) is lowered to a global aggregate during analysis, so it returns " +
302+
"one row over empty input, matching an aggregation with no GROUP BY clause. When false, " +
303+
"falls back to the legacy Expand-based lowering that returns no rows over empty input.")
304+
.withBindingPolicy(ConfigBindingPolicy.SESSION)
305+
.booleanConf
306+
.createWithDefault(true)
307+
295308
val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
296309
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
297310
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
130130
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
131131
checkAnalysis(originalPlan, expected)
132132

133+
// CUBE() over no columns is a single empty grouping set, i.e. a grand total, so it lowers
134+
// to a global Aggregate (no Expand) and returns one row even over empty input.
133135
val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
134-
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
135-
Expand(
136-
Seq(Seq(a, b, c, 0L)),
137-
Seq(a, b, c, gid),
138-
Project(Seq(a, b, c), r1)))
136+
val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1)
139137
checkAnalysis(originalPlan2, expected2)
140138
}
141139

@@ -149,12 +147,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
149147
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
150148
checkAnalysis(originalPlan, expected)
151149

150+
// ROLLUP() over no columns is a single empty grouping set, i.e. a grand total, so it lowers
151+
// to a global Aggregate (no Expand) and returns one row even over empty input.
152152
val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
153-
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
154-
Expand(
155-
Seq(Seq(a, b, c, 0L)),
156-
Seq(a, b, c, gid),
157-
Project(Seq(a, b, c), r1)))
153+
val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1)
158154
checkAnalysis(originalPlan2, expected2)
159155
}
160156

sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,15 +1931,13 @@ SELECT SUM(b) AS s
19311931
FROM aggSrc
19321932
GROUP BY GROUPING SETS (())
19331933
-- !query analysis
1934-
Aggregate [spark_grouping_id#xL], [sum(b#x) AS s#xL]
1935-
+- Expand [[a#x, b#x, 0]], [a#x, b#x, spark_grouping_id#xL]
1936-
+- Project [a#x, b#x]
1937-
+- SubqueryAlias aggsrc
1938-
+- View (`aggSrc`, [a#x, b#x])
1939-
+- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
1940-
+- Project [a#x, b#x]
1941-
+- SubqueryAlias aggSrc
1942-
+- LocalRelation [a#x, b#x]
1934+
Aggregate [sum(b#x) AS s#xL]
1935+
+- SubqueryAlias aggsrc
1936+
+- View (`aggSrc`, [a#x, b#x])
1937+
+- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
1938+
+- Project [a#x, b#x]
1939+
+- SubqueryAlias aggSrc
1940+
+- LocalRelation [a#x, b#x]
19431941

19441942

19451943
-- !query

0 commit comments

Comments
 (0)