Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -840,11 +840,12 @@ class Analyzer(
// We should make sure all expressions in condition have been resolved.
case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveReferences
// For the grand total this is a resolved Literal(0); otherwise the unresolved
// spark_grouping_id attribute is resolved by ResolveReferences.
val newCond = GroupingAnalyticsTransformer.replaceGroupingFunction(
expression = cond,
groupByExpressions = groupingExprs,
gid = VirtualColumn.groupingIdAttribute,
gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs),
newAlias = (child, name, qualifier) =>
Alias(child, name.get)(qualifier = qualifier)
)
Expand All @@ -854,8 +855,9 @@ class Analyzer(
case s @ Sort(order, _, child, _)
if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveReferences
val gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs)
// For the grand total this is a resolved Literal(0); otherwise the unresolved
// spark_grouping_id attribute is resolved by ResolveReferences.
val newOrder = order.map { expression =>
GroupingAnalyticsTransformer.replaceGroupingFunction(
expression = expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.ByteType

/**
Expand Down Expand Up @@ -64,6 +65,12 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
* +- LocalRelation [col1#0]
* }}}
*
* The grand-total-only case `GROUP BY GROUPING SETS (())` (and the equivalent empty `CUBE()` /
* `ROLLUP()`) is the exception: instead of an [[Expand]] it is lowered to a global [[Aggregate]]
* with no grouping expressions (see [[shouldLowerToGrandTotalAggregate]] and
* [[constructGrandTotalAggregate]]), so it returns one row over empty input, like an aggregation
* with no GROUP BY clause.
*
* @param newAlias Function to create new aliases, takes expression, optional name, and optional
* qualifier
* @param childOutput The output attributes of the child plan
Expand All @@ -81,34 +88,117 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
child: LogicalPlan,
aggregationExpressions: Seq[NamedExpression]): Aggregate = {

val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions)
if (shouldLowerToGrandTotalAggregate(groupByExpressions, selectedGroupByExpressions)) {
constructGrandTotalAggregate(newAlias, aggregationExpressions, child)
} else {
val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions)

val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
val expand = constructExpand(
selectedGroupByExpressions = selectedGroupByExpressions,
child = child,
groupByAliases = groupByAliases,
gid = gid,
childOutput = childOutput
)
val groupingAttributes = expand.output.drop(childOutput.length)
val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)()
val expand = constructExpand(
selectedGroupByExpressions = selectedGroupByExpressions,
child = child,
groupByAliases = groupByAliases,
gid = gid,
childOutput = childOutput
)
val groupingAttributes = expand.output.drop(childOutput.length)

val aggregations = constructAggregateExpressions(
newAlias = newAlias,
groupByExpressions = groupByExpressions,
aggregations = aggregationExpressions,
groupByAliases = groupByAliases,
groupingAttributes = groupingAttributes,
gid = gid
)
val aggregations = constructAggregateExpressions(
newAlias = newAlias,
groupByExpressions = groupByExpressions,
aggregations = aggregationExpressions,
groupByAliases = groupByAliases,
groupingAttributes = groupingAttributes,
gid = gid
)

Aggregate(
groupingExpressions = groupingAttributes,
aggregateExpressions = aggregations,
child = expand
)
}
}

val aggregate = Aggregate(
groupingExpressions = groupingAttributes,
/**
 * Pre-lowering check for the grand-total-only grouping spec, i.e. `GROUP BY GROUPING SETS (())`
 * and the equivalent empty `CUBE()` / `ROLLUP()`, which [[apply]] lowers to a global
 * [[Aggregate]].
 *
 * The result is `true` only when all of the following hold:
 *  - [[SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE]] is enabled,
 *  - there are no group-by expressions, and
 *  - exactly one grouping set is selected and that set is empty.
 *
 * [[isLoweredToGrandTotalAggregate]] is the matching check for the post-lowering side.
 *
 * @param groupByExpressions the group-by expressions of the spec
 * @param selectedGroupByExpressions the selected grouping sets, one inner sequence per set
 * @return whether the spec should be lowered to a global [[Aggregate]]
 */
def shouldLowerToGrandTotalAggregate(
    groupByExpressions: Seq[Expression],
    selectedGroupByExpressions: Seq[Seq[Expression]]): Boolean = {
  val loweringEnabled = conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE)
  // A single-element pattern only matches a sequence holding exactly one grouping set.
  val isSingleEmptyGroupingSet = selectedGroupByExpressions match {
    case Seq(onlyGroupingSet) => onlyGroupingSet.isEmpty
    case _ => false
  }
  loweringEnabled && groupByExpressions.isEmpty && isSingleEmptyGroupingSet
}

/**
 * Post-lowering check for the grand total: given the grouping expressions of a lowered
 * [[Aggregate]] with the `spark_grouping_id` key already stripped (the shape returned by
 * [[collectGroupingExpressions]]), report whether they are empty while
 * [[SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE]] is enabled.
 *
 * The flag participates in the check because, when it is disabled, the same grand total is
 * lowered via [[Expand]] and its `spark_grouping_id` key must still be referenced.
 *
 * Only emptiness of the collected grouping expressions is inspected, so an all-empty multi-set
 * [[Expand]] aggregate (e.g. `GROUP BY GROUPING SETS ((), ())`) matches as well. Its sole
 * grouping key is `spark_grouping_id`, and with no group-by columns every row's grouping id is
 * 0, so folding grouping_id() to the constant 0 is correct for that case too.
 *
 * @param groupingExpressions grouping expressions without the `spark_grouping_id` key
 * @return whether grouping functions over this aggregate can use a constant grouping id
 */
def isLoweredToGrandTotalAggregate(groupingExpressions: Seq[Expression]): Boolean = {
  val loweringEnabled = conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE)
  loweringEnabled && groupingExpressions.isEmpty
}

/**
 * Picks the grouping id expression to substitute for grouping()/grouping_id() over a lowered
 * grouping-analytics [[Aggregate]].
 *
 * For a grand total (see [[isLoweredToGrandTotalAggregate]]) this is the default literal of
 * [[GroupingID.dataType]], i.e. the constant 0. Otherwise it is the unresolved
 * `spark_grouping_id` attribute produced by [[Expand]].
 *
 * @param groupingExpressions grouping expressions without the `spark_grouping_id` key
 * @return the expression to use as the grouping id
 */
def groupingIdExpression(groupingExpressions: Seq[Expression]): Expression = {
  val isGrandTotal = isLoweredToGrandTotalAggregate(groupingExpressions)
  if (isGrandTotal) Literal.default(GroupingID.dataType) else VirtualColumn.groupingIdAttribute
}

/**
* Build a global [[Aggregate]] (with no grouping expressions) for the grand-total-only case
* `GROUP BY GROUPING SETS (())`.
*
* A single empty grouping set is a grand total, semantically identical to an aggregation with no
* GROUP BY clause. Lowering it via [[Expand]] would group by `spark_grouping_id` and so emit zero
* rows over empty input instead of one; a global [[Aggregate]] returns one (grand total) row,
* matching the GROUP BY-less form and the SQL standard.
*
* Any [[GroupingID]] in the aggregation expressions evaluates to the constant `0` (the only
* active grouping set is the empty one), and [[Grouping]] over a non-grouping column is rejected,
* both reusing [[replaceGroupingFunction]] with a literal grouping id in place of the
* Expand-produced `spark_grouping_id` attribute.
*/
private def constructGrandTotalAggregate(
newAlias: (Expression, Option[String], Seq[String]) => Alias,
aggregationExpressions: Seq[NamedExpression],
child: LogicalPlan): Aggregate = {
val groupingId = Literal.default(GroupingID.dataType)
val aggregations = aggregationExpressions.map { expression =>
replaceGroupingFunction(
expression = expression,
groupByExpressions = Seq.empty,
gid = groupingId,
newAlias = newAlias
).asInstanceOf[NamedExpression]
}
Aggregate(
groupingExpressions = Seq.empty,
aggregateExpressions = aggregations,
child = expand
child = child
)

aggregate
}

/**
Expand Down Expand Up @@ -153,19 +243,27 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper {
/**
* Collect all grouping expressions of the provided [[Aggregate]] except the last one. The last
* grouping key is expected to be the grouping id (`spark_grouping_id`); it is only validated,
* not returned.
*
* A grand-total `GROUP BY GROUPING SETS (())` is lowered to a global [[Aggregate]] with no
* grouping expressions (see [[apply]]), so there is no `spark_grouping_id` key and no grouping
* columns; return an empty sequence for it.
*/
def collectGroupingExpressions(aggregate: Aggregate): Seq[Expression] = {
val gid = aggregate.groupingExpressions.last
gid match {
case attributeReference: AttributeReference =>
if (attributeReference.name != VirtualColumn.groupingIdName) {
if (aggregate.groupingExpressions.isEmpty) {
Seq.empty
} else {
val gid = aggregate.groupingExpressions.last
gid match {
case attributeReference: AttributeReference =>
if (attributeReference.name != VirtualColumn.groupingIdName) {
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
}
case _ =>
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
}
case _ =>
throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
}
}

aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1)
aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,10 @@ package org.apache.spark.sql.catalyst.analysis.resolver

import org.apache.spark.sql.catalyst.analysis.GroupingAnalyticsTransformer
import org.apache.spark.sql.catalyst.expressions.{
AttributeReference,
BaseGroupingSets,
Expression,
GroupingAnalyticsExtractor,
SortOrder,
VirtualColumn
SortOrder
}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Sort}
import org.apache.spark.sql.errors.QueryCompilationErrors
Expand Down Expand Up @@ -111,7 +109,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi
* - If there is, collect grouping expressions from it using
* [[GroupingAnalyticsTransformer.collectGroupingExpressions]].
* - If there isn't, throw `groupingMustWithGroupingSetsOrCubeOrRollupError` exception.
* 2. Create a grouping ID attribute and resolve it using [[ExpressionResolver]].
* 2. Compute the grouping id via [[GroupingAnalyticsTransformer.groupingIdExpression]] (the
* constant 0 for a grand total, otherwise the `spark_grouping_id` attribute) and resolve it
* using [[ExpressionResolver]].
* 3. Replace [[SortOrder]] expressions using
* [[GroupingAnalyticsTransformer.replaceGroupingFunction]] (see its scala doc for more
* details).
Expand All @@ -128,10 +128,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi
scopes.current.baseAggregate.get
)

val groupingId = VirtualColumn.groupingIdAttribute
val resolvedGroupingId = expressionResolver
.resolveExpressionTreeInOperator(groupingId, sort)
.asInstanceOf[AttributeReference]
val groupingId = GroupingAnalyticsTransformer.groupingIdExpression(groupingExpressions)
val resolvedGroupingId =
expressionResolver.resolveExpressionTreeInOperator(groupingId, sort)

val orderExpressionsWithGroupingAnalytics = orderExpressions.map { orderExpression =>
GroupingAnalyticsTransformer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,19 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal, session-bound flag controlling how a grand-total-only grouping spec
// (`GROUP BY GROUPING SETS (())`, empty `CUBE()` / `ROLLUP()`) is analyzed. It is enabled by
// default; setting it to false restores the legacy Expand-based lowering described in the doc
// text below.
val LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE =
buildConf("spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled")
.internal()
.version("4.3.0")
.doc(
"When true, a grand-total GROUP BY GROUPING SETS (()) (and the equivalent empty " +
"CUBE() / ROLLUP()) is lowered to a global aggregate during analysis, so it returns " +
"one row over empty input, matching an aggregation with no GROUP BY clause. When false, " +
"falls back to the legacy Expand-based lowering that returns no rows over empty input.")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(true)

val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
checkAnalysis(originalPlan, expected)

// CUBE() over no columns is a single empty grouping set, i.e. a grand total, so it lowers
// to a global Aggregate (no Expand) and returns one row even over empty input.
val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
Expand(
Seq(Seq(a, b, c, 0L)),
Seq(a, b, c, gid),
Project(Seq(a, b, c), r1)))
val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1)
checkAnalysis(originalPlan2, expected2)
}

Expand All @@ -149,12 +147,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
checkAnalysis(originalPlan, expected)

// ROLLUP() over no columns is a single empty grouping set, i.e. a grand total, so it lowers
// to a global Aggregate (no Expand) and returns one row even over empty input.
val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
Expand(
Seq(Seq(a, b, c, 0L)),
Seq(a, b, c, gid),
Project(Seq(a, b, c), r1)))
val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1)
checkAnalysis(originalPlan2, expected2)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1931,15 +1931,13 @@ SELECT SUM(b) AS s
FROM aggSrc
GROUP BY GROUPING SETS (())
-- !query analysis
Aggregate [spark_grouping_id#xL], [sum(b#x) AS s#xL]
+- Expand [[a#x, b#x, 0]], [a#x, b#x, spark_grouping_id#xL]
+- Project [a#x, b#x]
+- SubqueryAlias aggsrc
+- View (`aggSrc`, [a#x, b#x])
+- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
+- Project [a#x, b#x]
+- SubqueryAlias aggSrc
+- LocalRelation [a#x, b#x]
Aggregate [sum(b#x) AS s#xL]
+- SubqueryAlias aggsrc
+- View (`aggSrc`, [a#x, b#x])
+- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x]
+- Project [a#x, b#x]
+- SubqueryAlias aggSrc
+- LocalRelation [a#x, b#x]


-- !query
Expand Down
Loading