diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6cd985fd01fa0..45f9d7e044550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -840,11 +840,12 @@ class Analyzer( // We should make sure all expressions in condition have been resolved. case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved => val groupingExprs = findGroupingExprs(child) - // The unresolved grouping id will be resolved by ResolveReferences + // For the grand total this is a resolved Literal(0); otherwise the unresolved + // spark_grouping_id attribute is resolved by ResolveReferences. val newCond = GroupingAnalyticsTransformer.replaceGroupingFunction( expression = cond, groupByExpressions = groupingExprs, - gid = VirtualColumn.groupingIdAttribute, + gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs), newAlias = (child, name, qualifier) => Alias(child, name.get)(qualifier = qualifier) ) @@ -854,8 +855,9 @@ class Analyzer( case s @ Sort(order, _, child, _) if order.exists(hasGroupingFunction) && order.forall(_.resolved) => val groupingExprs = findGroupingExprs(child) - val gid = VirtualColumn.groupingIdAttribute - // The unresolved grouping id will be resolved by ResolveReferences + val gid = GroupingAnalyticsTransformer.groupingIdExpression(groupingExprs) + // For the grand total this is a resolved Literal(0); otherwise the unresolved + // spark_grouping_id attribute is resolved by ResolveReferences. 
val newOrder = order.map { expression => GroupingAnalyticsTransformer.replaceGroupingFunction( expression = expression, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala index ade9e1d2faf9f..898dc3b9befdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ByteType /** @@ -64,6 +65,12 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper { * +- LocalRelation [col1#0] * }}} * + * The grand-total-only case `GROUP BY GROUPING SETS (())` (and the equivalent empty `CUBE()` / + * `ROLLUP()`) is the exception: instead of an [[Expand]] it is lowered to a global [[Aggregate]] + * with no grouping expressions (see [[shouldLowerToGrandTotalAggregate]] and + * [[constructGrandTotalAggregate]]), so it returns one row over empty input, like an aggregation + * with no GROUP BY clause. 
+ * * @param newAlias Function to create new aliases, takes expression, optional name, and optional * qualifier * @param childOutput The output attributes of the child plan @@ -81,34 +88,117 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper { child: LogicalPlan, aggregationExpressions: Seq[NamedExpression]): Aggregate = { - val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions) + if (shouldLowerToGrandTotalAggregate(groupByExpressions, selectedGroupByExpressions)) { + constructGrandTotalAggregate(newAlias, aggregationExpressions, child) + } else { + val groupByAliases = constructGroupByAlias(newAlias, groupByExpressions) - val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)() - val expand = constructExpand( - selectedGroupByExpressions = selectedGroupByExpressions, - child = child, - groupByAliases = groupByAliases, - gid = gid, - childOutput = childOutput - ) - val groupingAttributes = expand.output.drop(childOutput.length) + val gid = AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, false)() + val expand = constructExpand( + selectedGroupByExpressions = selectedGroupByExpressions, + child = child, + groupByAliases = groupByAliases, + gid = gid, + childOutput = childOutput + ) + val groupingAttributes = expand.output.drop(childOutput.length) - val aggregations = constructAggregateExpressions( - newAlias = newAlias, - groupByExpressions = groupByExpressions, - aggregations = aggregationExpressions, - groupByAliases = groupByAliases, - groupingAttributes = groupingAttributes, - gid = gid - ) + val aggregations = constructAggregateExpressions( + newAlias = newAlias, + groupByExpressions = groupByExpressions, + aggregations = aggregationExpressions, + groupByAliases = groupByAliases, + groupingAttributes = groupingAttributes, + gid = gid + ) + + Aggregate( + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregations, + child = expand + ) + } + } - 
val aggregate = Aggregate( - groupingExpressions = groupingAttributes, + /** + * Whether a grouping-set spec is the grand-total-only case `GROUP BY GROUPING SETS (())` (and + * the equivalent empty `CUBE()`/`ROLLUP()`) that [[apply]] lowers to a global [[Aggregate]]: + * no leading group-by expressions and a single empty grouping set, with + * [[SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE]] enabled. This is the + * pre-lowering decision; [[isLoweredToGrandTotalAggregate]] detects the same case post-lowering. + */ + def shouldLowerToGrandTotalAggregate( + groupByExpressions: Seq[Expression], + selectedGroupByExpressions: Seq[Seq[Expression]]): Boolean = { + conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE) && + groupByExpressions.isEmpty && + selectedGroupByExpressions.length == 1 && + selectedGroupByExpressions.head.isEmpty + } + + /** + * Whether a lowered [[Aggregate]] is the grand total produced by + * [[shouldLowerToGrandTotalAggregate]]: its resolved grouping expressions (without the + * `spark_grouping_id` key, as returned by [[collectGroupingExpressions]]) are empty, with the + * flag enabled. The flag is part of the check because with it off the same grand total is + * lowered via [[Expand]], whose `spark_grouping_id` key must still be referenced. + * + * This keys purely on the collected grouping expressions being empty, so it also matches a + * value-equivalent all-empty multi-set [[Expand]] aggregate (e.g. `GROUP BY GROUPING SETS ((), + * ())`), whose only grouping key is `spark_grouping_id` and whose every row's grouping id is + * likewise 0 (there are no group-by columns). Folding grouping_id() to the constant 0 is correct + * for both. 
+ */ + def isLoweredToGrandTotalAggregate(groupingExpressions: Seq[Expression]): Boolean = { + conf.getConf(SQLConf.LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE) && + groupingExpressions.isEmpty + } + + /** + * The grouping id to substitute for grouping()/grouping_id() over a lowered grouping-analytics + * [[Aggregate]]: the constant 0 for a grand total ([[isLoweredToGrandTotalAggregate]]), otherwise + * the unresolved `spark_grouping_id` attribute produced by [[Expand]]. + */ + def groupingIdExpression(groupingExpressions: Seq[Expression]): Expression = { + if (isLoweredToGrandTotalAggregate(groupingExpressions)) { + Literal.default(GroupingID.dataType) + } else { + VirtualColumn.groupingIdAttribute + } + } + + /** + * Build a global [[Aggregate]] (with no grouping expressions) for the grand-total-only case + * `GROUP BY GROUPING SETS (())`. + * + * A single empty grouping set is a grand total, semantically identical to an aggregation with no + * GROUP BY clause. Lowering it via [[Expand]] would group by `spark_grouping_id` and so emit zero + * rows over empty input instead of one; a global [[Aggregate]] returns one (grand total) row, + * matching the GROUP BY-less form and the SQL standard. + * + * Any [[GroupingID]] in the aggregation expressions evaluates to the constant `0` (the only + * active grouping set is the empty one), and [[Grouping]] over a non-grouping column is rejected, + * both reusing [[replaceGroupingFunction]] with a literal grouping id in place of the + * Expand-produced `spark_grouping_id` attribute. 
+ */ + private def constructGrandTotalAggregate( + newAlias: (Expression, Option[String], Seq[String]) => Alias, + aggregationExpressions: Seq[NamedExpression], + child: LogicalPlan): Aggregate = { + val groupingId = Literal.default(GroupingID.dataType) + val aggregations = aggregationExpressions.map { expression => + replaceGroupingFunction( + expression = expression, + groupByExpressions = Seq.empty, + gid = groupingId, + newAlias = newAlias + ).asInstanceOf[NamedExpression] + } + Aggregate( + groupingExpressions = Seq.empty, aggregateExpressions = aggregations, - child = expand + child = child ) - - aggregate } /** @@ -153,19 +243,27 @@ object GroupingAnalyticsTransformer extends SQLConfHelper with AliasHelper { /** * Collect the last grouping expression since the provided [[Aggregate]] should have grouping id * as the last grouping key. + * + * A grand-total `GROUP BY GROUPING SETS (())` is lowered to a global [[Aggregate]] with no + * grouping expressions (see [[apply]]), so there is no `spark_grouping_id` key and no grouping + * columns; return an empty sequence for it. 
*/ def collectGroupingExpressions(aggregate: Aggregate): Seq[Expression] = { - val gid = aggregate.groupingExpressions.last - gid match { - case attributeReference: AttributeReference => - if (attributeReference.name != VirtualColumn.groupingIdName) { + if (aggregate.groupingExpressions.isEmpty) { + Seq.empty + } else { + val gid = aggregate.groupingExpressions.last + gid match { + case attributeReference: AttributeReference => + if (attributeReference.name != VirtualColumn.groupingIdName) { + throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() + } + case _ => throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() - } - case _ => - throw QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError() - } + } - aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1) + aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 1) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala index af92e2cb453ea..3f6b154eae65e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAnalyticsResolver.scala @@ -19,12 +19,10 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.analysis.GroupingAnalyticsTransformer import org.apache.spark.sql.catalyst.expressions.{ - AttributeReference, BaseGroupingSets, Expression, GroupingAnalyticsExtractor, - SortOrder, - VirtualColumn + SortOrder } import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, Sort} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -111,7 +109,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: 
Expressi * - If there is, collect grouping expressions from it using * [[GroupingAnalyticsTransformer.collectGroupingExpressions]]. * - If there isn't, throw `groupingMustWithGroupingSetsOrCubeOrRollupError` exception. - * 2. Create a grouping ID attribute and resolve it using [[ExpressionResolver]]. + * 2. Compute the grouping id via [[GroupingAnalyticsTransformer.groupingIdExpression]] (the + * constant 0 for a grand total, otherwise the `spark_grouping_id` attribute) and resolve it + * using [[ExpressionResolver]]. * 3. Replace [[SortOrder]] expressions using * [[GroupingAnalyticsTransformer.replaceGroupingFunction]] (see its scala doc for more * details). @@ -128,10 +128,9 @@ class GroupingAnalyticsResolver(resolver: Resolver, expressionResolver: Expressi scopes.current.baseAggregate.get ) - val groupingId = VirtualColumn.groupingIdAttribute - val resolvedGroupingId = expressionResolver - .resolveExpressionTreeInOperator(groupingId, sort) - .asInstanceOf[AttributeReference] + val groupingId = GroupingAnalyticsTransformer.groupingIdExpression(groupingExpressions) + val resolvedGroupingId = + expressionResolver.resolveExpressionTreeInOperator(groupingId, sort) val orderExpressionsWithGroupingAnalytics = orderExpressions.map { orderExpression => GroupingAnalyticsTransformer diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6776f88ed1ef8..d57c34b919a84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -292,6 +292,19 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LOWER_EMPTY_GROUPING_SET_TO_GLOBAL_AGGREGATE = + buildConf("spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled") + .internal() + .version("4.3.0") + .doc( + "When true, a grand-total GROUP BY GROUPING SETS (()) (and the equivalent empty " + + 
"CUBE() / ROLLUP()) is lowered to a global aggregate during analysis, so it returns " + + "one row over empty input, matching an aggregation with no GROUP BY clause. When false, " + + "falls back to the legacy Expand-based lowering that returns no rows over empty input.") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(true) + val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS = buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index ceb68ab0c92bc..4d51eb2c14a37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -130,12 +130,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) + // CUBE() over no columns is a single empty grouping set, i.e. a grand total, so it lowers + // to a global Aggregate (no Expand) and returns one row even over empty input. val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) - val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), - Expand( - Seq(Seq(a, b, c, 0L)), - Seq(a, b, c, gid), - Project(Seq(a, b, c), r1))) + val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1) checkAnalysis(originalPlan2, expected2) } @@ -149,12 +147,10 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { Project(Seq(a, b, c, a.as("a"), b.as("b")), r1))) checkAnalysis(originalPlan, expected) + // ROLLUP() over no columns is a single empty grouping set, i.e. 
a grand total, so it lowers + // to a global Aggregate (no Expand) and returns one row even over empty input. val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1) - val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")), - Expand( - Seq(Seq(a, b, c, 0L)), - Seq(a, b, c, gid), - Project(Seq(a, b, c), r1))) + val expected2 = Aggregate(Seq.empty[Expression], Seq(count(c).as("count(c)")), r1) checkAnalysis(originalPlan2, expected2) } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out index fe6931991e223..1d3bf1a6171b4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-analytics.sql.out @@ -1931,15 +1931,13 @@ SELECT SUM(b) AS s FROM aggSrc GROUP BY GROUPING SETS (()) -- !query analysis -Aggregate [spark_grouping_id#xL], [sum(b#x) AS s#xL] -+- Expand [[a#x, b#x, 0]], [a#x, b#x, spark_grouping_id#xL] - +- Project [a#x, b#x] - +- SubqueryAlias aggsrc - +- View (`aggSrc`, [a#x, b#x]) - +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] - +- Project [a#x, b#x] - +- SubqueryAlias aggSrc - +- LocalRelation [a#x, b#x] +Aggregate [sum(b#x) AS s#xL] ++- SubqueryAlias aggsrc + +- View (`aggSrc`, [a#x, b#x]) + +- Project [cast(a#x as int) AS a#x, cast(b#x as int) AS b#x] + +- Project [a#x, b#x] + +- SubqueryAlias aggSrc + +- LocalRelation [a#x, b#x] -- !query diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out index 2c63fb1525a46..ef70bdd1cc35f 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set.sql.out @@ -219,15 +219,181 @@ SELECT c1 FROM (values (1,2), (3,2)) t(c1, 
c2) GROUP BY GROUPING SETS (()) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "MISSING_AGGREGATION", + "errorClass" : "MISSING_GROUP_BY", + "sqlState" : "42803", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 48, + "stopIndex" : 74, + "fragment" : "GROUP BY GROUPING SETS (())" + } ] +} + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [count(1) AS c#xL] ++- Filter (k#x > 100) + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 +-- !query analysis +Aggregate [count(1) AS c#xL] ++- Filter (k#x > 100) + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c, grouping_id() AS g +FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [count(1) AS c#xL, 0 AS g#xL] ++- Filter (k#x > 100) + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [count(1) AS c#xL] ++- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- Filter (v#x > 100) + +- SubqueryAlias t + +- LocalRelation [v#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 +-- !query analysis +Aggregate [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- Filter (v#x > 100) + +- SubqueryAlias t + +- LocalRelation [v#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, 
count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- SubqueryAlias t + +- LocalRelation [v#x] + + +-- !query +SELECT v, count(*) AS c FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY v +-- !query analysis +Aggregate [v#x], [v#x, count(1) AS c#xL] ++- Filter (v#x > 100) + +- SubqueryAlias t + +- LocalRelation [v#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0 +-- !query analysis +Filter (0 = cast(0 as bigint)) ++- Aggregate [count(1) AS c#xL] + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id() +-- !query analysis +Sort [0 ASC NULLS FIRST], true ++- Aggregate [count(1) AS c#xL] + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT grouping(k) FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "GROUPING_COLUMN_MISMATCH", "sqlState" : "42803", "messageParameters" : { - "expression" : "\"c1\"", - "expressionAnyValue" : "\"any_value(c1)\"" + "grouping" : "k#x", + "groupingColumns" : "" } } +-- !query +CREATE TEMPORARY VIEW grouping_grand_total AS + SELECT * FROM VALUES (1, 10), (2, 20), (3, 30) AS t(k, v) +-- !query analysis +CreateViewCommand `grouping_grand_total`, SELECT * FROM VALUES (1, 10), (2, 20), (3, 30) AS t(k, v), false, false, LocalTempView, UNSUPPORTED, true + +- Project [k#x, v#x] + +- SubqueryAlias t + +- LocalRelation [k#x, v#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- Filter (k#x > 100) + 
+- SubqueryAlias grouping_grand_total + +- View (`grouping_grand_total`, [k#x, v#x]) + +- Project [cast(k#x as int) AS k#x, cast(v#x as int) AS v#x] + +- Project [k#x, v#x] + +- SubqueryAlias t + +- LocalRelation [k#x, v#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- SubqueryAlias grouping_grand_total + +- View (`grouping_grand_total`, [k#x, v#x]) + +- Project [cast(k#x as int) AS k#x, cast(v#x as int) AS v#x] + +- Project [k#x, v#x] + +- SubqueryAlias t + +- LocalRelation [k#x, v#x] + + +-- !query +SELECT count(*) AS c, grouping_id() AS g +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [count(1) AS c#xL, 0 AS g#xL] ++- Filter (k#x > 100) + +- SubqueryAlias grouping_grand_total + +- View (`grouping_grand_total`, [k#x, v#x]) + +- Project [cast(k#x as int) AS k#x, cast(v#x as int) AS v#x] + +- Project [k#x, v#x] + +- SubqueryAlias t + +- LocalRelation [k#x, v#x] + + +-- !query +DROP VIEW grouping_grand_total +-- !query analysis +DropTempViewCommand grouping_grand_total, false + + -- !query SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1)) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set_grand_total_disabled.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set_grand_total_disabled.sql.out new file mode 100644 index 0000000000000..83b2f5e64a006 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/grouping_set_grand_total_disabled.sql.out @@ -0,0 +1,56 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate 
[spark_grouping_id#xL], [count(1) AS c#xL] ++- Expand [[k#x, 0]], [k#x, spark_grouping_id#xL] + +- Project [k#x] + +- Filter (k#x > 100) + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [spark_grouping_id#xL], [sum(v#x) AS total#xL, avg(v#x) AS mean#x, max(v#x) AS hi#x, count(1) AS c#xL] ++- Expand [[v#x, 0]], [v#x, spark_grouping_id#xL] + +- Project [v#x] + +- Filter (v#x > 100) + +- SubqueryAlias t + +- LocalRelation [v#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query analysis +Aggregate [spark_grouping_id#xL], [count(1) AS c#xL] ++- Expand [[k#x, 0]], [k#x, spark_grouping_id#xL] + +- Project [k#x] + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0 +-- !query analysis +Project [c#xL] ++- Filter (spark_grouping_id#xL = cast(0 as bigint)) + +- Aggregate [spark_grouping_id#xL], [count(1) AS c#xL, spark_grouping_id#xL] + +- Expand [[k#x, 0]], [k#x, spark_grouping_id#xL] + +- Project [k#x] + +- SubqueryAlias t + +- LocalRelation [k#x] + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id() +-- !query analysis +Project [c#xL] ++- Sort [spark_grouping_id#xL ASC NULLS FIRST], true + +- Aggregate [spark_grouping_id#xL], [count(1) AS c#xL, spark_grouping_id#xL] + +- Expand [[k#x, 0]], [k#x, spark_grouping_id#xL] + +- Project [k#x] + +- SubqueryAlias t + +- LocalRelation [k#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql index 013a5419f8d58..386236ef1b6fb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/grouping_set.sql @@ -55,6 +55,62 @@ SELECT a, b, c, count(d) FROM grouping GROUP BY WITH CUBE; SELECT c1 FROM (values (1,2), (3,2)) t(c1, c2) GROUP BY GROUPING SETS (()); +-- GROUP BY GROUPING SETS (()) is a grand total and must return one row over empty input, just +-- like an aggregation without a GROUP BY clause. +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()); + +-- Semantically identical query without GROUP BY, for comparison. +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100; + +-- Grand total over empty input with grouping_id() in the SELECT list. +SELECT count(*) AS c, grouping_id() AS g +FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()); + +-- Grand total over non-empty input still returns the single aggregated row. +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()); + +-- Meaningful aggregates make the grand-total row explicit. Over empty input it is a single row of +-- NULL measures with count 0 (identical to the same query with no GROUP BY); over non-empty input +-- it carries the real totals. +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()); + +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100; + +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) GROUP BY GROUPING SETS (()); + +-- Contrast: grouping by a real column over empty input returns no rows (no groups), whereas the +-- grand total above returns one row -- the grand-total-specific semantics this fix restores. +SELECT v, count(*) AS c FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY v; + +-- grouping_id() over a grand total is 0, so it resolves in HAVING and ORDER BY. 
+SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0; +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id(); + +-- grouping() over a grand total references a non-grouping column and is rejected. +SELECT grouping(k) FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()); + +-- Grand total over a named relation rather than inline data in the query: the grand total reads +-- FROM a temporary view. +CREATE TEMPORARY VIEW grouping_grand_total AS + SELECT * FROM VALUES (1, 10), (2, 20), (3, 30) AS t(k, v); + +-- Empty input (filter removes all rows): one grand-total row of NULL measures with count 0. +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()); + +-- Non-empty input: the real totals in a single grand-total row. +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total GROUP BY GROUPING SETS (()); + +-- grouping_id() over the grand total folds to 0. +SELECT count(*) AS c, grouping_id() AS g +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()); + +DROP VIEW grouping_grand_total; + -- duplicate entries in grouping sets SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP BY GROUPING SETS ((k1),(k1,k2),(k2,k1)); diff --git a/sql/core/src/test/resources/sql-tests/inputs/grouping_set_grand_total_disabled.sql b/sql/core/src/test/resources/sql-tests/inputs/grouping_set_grand_total_disabled.sql new file mode 100644 index 0000000000000..70a49518044cd --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/grouping_set_grand_total_disabled.sql @@ -0,0 +1,21 @@ +-- Kill-switch coverage for the grand-total GROUP BY GROUPING SETS (()) lowering. 
With +-- spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled = false, a +-- grand-total GROUP BY GROUPING SETS (()) reverts to the legacy Expand-based lowering, which +-- returns no rows over empty input. Runs under both analyzers (dual-run) to lock down parity in +-- the flag-off state. +--SET spark.sql.analyzer.lowerEmptyGroupingSetToGlobalAggregate.enabled=false + +-- Empty input: legacy lowering returns no rows (the pre-fix behavior the kill switch restores). +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()); + +-- Same meaningful query as in grouping_set.sql: with the fix off this returns no rows; with the +-- fix on it returns one row of NULL measures with count 0. The contrast is the correctness fix. +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()); + +-- Non-empty input still returns the single aggregated row. +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()); + +-- grouping_id() in HAVING/ORDER BY resolves to the Expand spark_grouping_id (value 0). 
+SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0; +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id(); diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out index 21e70c63535cb..a062a5a99dca9 100644 --- a/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set.sql.out @@ -198,15 +198,162 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "MISSING_AGGREGATION", + "errorClass" : "MISSING_GROUP_BY", + "sqlState" : "42803", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 48, + "stopIndex" : 74, + "fragment" : "GROUP BY GROUPING SETS (())" + } ] +} + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint> +-- !query output +0 + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 +-- !query schema +struct<c:bigint> +-- !query output +0 + + +-- !query +SELECT count(*) AS c, grouping_id() AS g +FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint,g:bigint> +-- !query output +0 0 + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint> +-- !query output +3 + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output +NULL NULL NULL 0 + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output +NULL NULL NULL 0 + + +-- !query +SELECT sum(v) 
AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) GROUP BY GROUPING SETS (()) +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output +60 20.0 30 3 + + +-- !query +SELECT v, count(*) AS c FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY v +-- !query schema +struct<v:int,c:bigint> +-- !query output + + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0 +-- !query schema +struct<c:bigint> +-- !query output +3 + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id() +-- !query schema +struct<c:bigint> +-- !query output +3 + + +-- !query +SELECT grouping(k) FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "GROUPING_COLUMN_MISMATCH", "sqlState" : "42803", "messageParameters" : { - "expression" : "\"c1\"", - "expressionAnyValue" : "\"any_value(c1)\"" + "grouping" : "k#x", + "groupingColumns" : "" } } +-- !query +CREATE TEMPORARY VIEW grouping_grand_total AS + SELECT * FROM VALUES (1, 10), (2, 20), (3, 30) AS t(k, v) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output +NULL NULL NULL 0 + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM grouping_grand_total GROUP BY GROUPING SETS (()) +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output +60 20.0 30 3 + + +-- !query +SELECT count(*) AS c, grouping_id() AS g +FROM grouping_grand_total WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint,g:bigint> +-- !query output +0 0 + + +-- !query +DROP VIEW grouping_grand_total +-- !query schema +struct<> +-- !query output + + + -- !query SELECT k1, k2, avg(v) FROM (VALUES (1,1,1),(2,2,2)) AS t(k1,k2,v) GROUP 
BY GROUPING SETS ((k1),(k1,k2),(k2,k1)) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/grouping_set_grand_total_disabled.sql.out b/sql/core/src/test/resources/sql-tests/results/grouping_set_grand_total_disabled.sql.out new file mode 100644 index 0000000000000..325a0ae960c2d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/grouping_set_grand_total_disabled.sql.out @@ -0,0 +1,40 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) WHERE k > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint> +-- !query output + + + +-- !query +SELECT sum(v) AS total, avg(v) AS mean, max(v) AS hi, count(*) AS c +FROM VALUES (10), (20), (30) AS t(v) WHERE v > 100 GROUP BY GROUPING SETS (()) +-- !query schema +struct<total:bigint,mean:double,hi:int,c:bigint> +-- !query output + + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) +-- !query schema +struct<c:bigint> +-- !query output +3 + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) HAVING grouping_id() = 0 +-- !query schema +struct<c:bigint> +-- !query output +3 + + +-- !query +SELECT count(*) AS c FROM VALUES (1), (2), (3) AS t(k) GROUP BY GROUPING SETS (()) ORDER BY grouping_id() +-- !query schema +struct<c:bigint> +-- !query output +3 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b18734d12e702..97931a007605b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -132,6 +132,17 @@ class DataFrameAggregateSuite extends SharedSparkSession ) } + test("cube()/rollup() with no grouping columns return one grand-total row over empty input") { + // With no grouping columns, cube()/rollup() lower to a global aggregate (the grand total), + // which returns one row 
even over empty input -- like an aggregation with no GROUP BY clause. + // This is the DataFrame-API surface for the empty CUBE/ROLLUP case (not expressible in SQL). + checkAnswer(spark.range(0).cube().count(), Row(0L)) + checkAnswer(spark.range(0).rollup().count(), Row(0L)) + // Non-empty input still collapses to the single grand-total row. + checkAnswer(spark.range(3).cube().count(), Row(3L)) + checkAnswer(spark.range(3).rollup().count(), Row(3L)) + } + test("rollup") { checkAnswer( courseSales.rollup("course", "year").sum("earnings"),