diff --git a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd index 12d66b284b0..fe181b2febc 100644 --- a/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd +++ b/exec/java-exec/src/main/codegen/data/DateIntervalFunc.tdd @@ -22,6 +22,8 @@ {truncInputTypes: ["Date", "TimeStamp", "Time", "Interval", "IntervalDay", "IntervalYear"] }, {truncUnits : ["Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter", "Decade", "Century", "Millennium" ] }, {timestampDiffUnits : ["Nanosecond", "Microsecond", "Millisecond", "Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter"] }, + {timestampAddUnits : ["Nanosecond", "Microsecond", "Millisecond", "Second", "Minute", "Hour", "Day", "Month", "Year", "Week", "Quarter"] }, + {timestampAddInputTypes : ["Date", "TimeStamp", "Time"] }, { varCharToDate: [ diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java new file mode 100644 index 00000000000..a2415ab4882 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/GroupingFunctions.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.expr.fn.impl; + +import org.apache.drill.exec.expr.DrillSimpleFunc; +import org.apache.drill.exec.expr.annotations.FunctionTemplate; +import org.apache.drill.exec.expr.annotations.Output; +import org.apache.drill.exec.expr.annotations.Param; +import org.apache.drill.exec.expr.holders.IntHolder; + +/** + * Functions for working with GROUPING SETS, ROLLUP, and CUBE. + * + * Note: These are internal helper functions. The actual GROUPING() and GROUPING_ID() + * SQL functions need special query rewriting to work correctly with GROUPING SETS. + */ +public class GroupingFunctions { + + /** + * GROUPING_ID_INTERNAL - Returns the grouping ID bitmap. + * This is an internal function that will be called with the $g column value. + */ + @FunctionTemplate(name = "grouping_id_internal", + scope = FunctionTemplate.FunctionScope.SIMPLE, + nulls = FunctionTemplate.NullHandling.NULL_IF_NULL) + public static class GroupingIdInternal implements DrillSimpleFunc { + + @Param IntHolder groupingId; + @Output IntHolder out; + + public void setup() { + } + + public void eval() { + out.value = groupingId.value; + } + } + + /** + * GROUPING_INTERNAL - Returns 1 if the specified bit in the grouping ID is set, 0 otherwise. + * This is an internal function that extracts a specific bit from the grouping ID. + * + * @param groupingId The grouping ID bitmap ($g column value) + * @param bitPosition The bit position to check (0-based) + */ + @FunctionTemplate(name = "grouping_internal", + scope = FunctionTemplate.FunctionScope.SIMPLE, + nulls = FunctionTemplate.NullHandling.NULL_IF_NULL) + public static class GroupingInternal implements DrillSimpleFunc { + + @Param IntHolder groupingId; + @Param IntHolder bitPosition; + @Output IntHolder out; + + public void setup() { + } + + public void eval() { + // Extract the bit at bitPosition from groupingId + // Bit is 1 if column is NOT in the grouping set (i.e., it's a grouping NULL) + out.value = (groupingId.value >> bitPosition.value) & 1; + } + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java index 59b4bfdb094..4bc4fbd51b8 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/config/UnionAll.java @@ -33,9 +33,22 @@ public class UnionAll extends AbstractMultiple { public static final String OPERATOR_TYPE = "UNION"; + private final boolean isGroupingSetsExpansion; + @JsonCreator - public UnionAll(@JsonProperty("children") List children) { + public UnionAll(@JsonProperty("children") List children, + @JsonProperty("isGroupingSetsExpansion") Boolean isGroupingSetsExpansion) { super(children); + this.isGroupingSetsExpansion = isGroupingSetsExpansion != null ? isGroupingSetsExpansion : false; + } + + public UnionAll(List children) { + this(children, false); + } + + @JsonProperty("isGroupingSetsExpansion") + public boolean isGroupingSetsExpansion() { + return isGroupingSetsExpansion; } @Override @@ -45,7 +58,7 @@ public T accept(PhysicalVisitor physicalVis @Override public PhysicalOperator getNewWithChildren(List children) { - return new UnionAll(children); + return new UnionAll(children, isGroupingSetsExpansion); } @Override diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java index fad14184fa5..784e78ec9e5 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java @@ -17,13 +17,7 @@ */ package org.apache.drill.exec.physical.impl.union; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.NoSuchElementException; -import java.util.Stack; - +import com.google.common.base.Preconditions; import org.apache.calcite.util.Pair; import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.expression.ErrorCollector; @@ -59,10 +53,16 @@ import org.apache.drill.exec.vector.FixedWidthVector; import org.apache.drill.exec.vector.SchemaChangeCallBack; import org.apache.drill.exec.vector.ValueVector; -import com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Stack; + public class UnionAllRecordBatch extends AbstractBinaryRecordBatch { private static final Logger logger = LoggerFactory.getLogger(UnionAllRecordBatch.class); @@ -278,10 +278,14 @@ private void inferOutputFieldsBothSide(final BatchSchema leftSchema, final Batch final Iterator leftIter = leftSchema.iterator(); final Iterator rightIter = rightSchema.iterator(); + logger.debug("UnionAll inferring schema: isGroupingSetsExpansion={}", popConfig.isGroupingSetsExpansion()); int index = 1; while (leftIter.hasNext() && rightIter.hasNext()) { MaterializedField leftField = leftIter.next(); MaterializedField rightField = rightIter.next(); + logger.debug("Column {}: left='{}' type={}, right='{}' type={}", + index, leftField.getName(), leftField.getType().getMinorType(), + rightField.getName(), rightField.getType().getMinorType()); if (Types.isSameTypeAndMode(leftField.getType(), rightField.getType())) { MajorType.Builder builder = MajorType.newBuilder() @@ -301,15 +305,7 @@ private void inferOutputFieldsBothSide(final BatchSchema leftSchema, final Batch builder.setMinorType(leftField.getType().getMinorType()); builder = Types.calculateTypePrecisionAndScale(leftField.getType(), rightField.getType(), builder); } else { - TypeProtos.MinorType outputMinorType = TypeCastRules.getLeastRestrictiveType( - leftField.getType().getMinorType(), - rightField.getType().getMinorType() - ); - if (outputMinorType == null) { - throw new DrillRuntimeException("Type mismatch between " + leftField.getType().getMinorType().toString() + - " on the left side and " + rightField.getType().getMinorType().toString() + - " on the right side in column " + index + " of UNION ALL"); - } + TypeProtos.MinorType outputMinorType = resolveUnionColumnType(leftField, rightField, index); builder.setMinorType(outputMinorType); } @@ -328,6 +324,46 @@ private void inferOutputFieldsBothSide(final BatchSchema leftSchema, final Batch "Mismatch of column count should have been detected when validating sqlNode at planning"; } + /** + * Determines the output type for a UNION ALL column when combining two types. + *

+ * Special handling is applied for GROUPING SETS expansion: + * - Drill represents NULL columns as INT during grouping sets expansion. + * - If one side is INT (likely a NULL placeholder) and the other is not, prefer the non-INT type. + *

+ * For all other cases, the least restrictive type according to Drill's type cast rules is returned. + * + * @param leftField The type of the left column + * @param rightField The type of the right column + * @param index The column index (for logging) + * @return The resolved output type + * @throws DrillRuntimeException if types are incompatible + */ + private TypeProtos.MinorType resolveUnionColumnType(MaterializedField leftField, + MaterializedField rightField, + int index) { + TypeProtos.MinorType leftType = leftField.getType().getMinorType(); + TypeProtos.MinorType rightType = rightField.getType().getMinorType(); + + boolean isGroupingSets = popConfig.isGroupingSetsExpansion(); + boolean leftIsPlaceholder = leftType == TypeProtos.MinorType.INT && rightType != TypeProtos.MinorType.INT; + boolean rightIsPlaceholder = rightType == TypeProtos.MinorType.INT && leftType != TypeProtos.MinorType.INT; + + if (isGroupingSets && (leftIsPlaceholder || rightIsPlaceholder)) { + TypeProtos.MinorType outputType = leftIsPlaceholder ? rightType : leftType; + logger.debug("GROUPING SETS: Preferring {} over INT for column {}", outputType, index); + return outputType; + } + + TypeProtos.MinorType outputType = TypeCastRules.getLeastRestrictiveType(leftType, rightType); + if (outputType == null) { + throw new DrillRuntimeException("Type mismatch between " + leftType + + " and " + rightType + " in column " + index + " of UNION ALL"); + } + logger.debug("Using standard type rules: {} + {} -> {}", leftType, rightType, outputType); + return outputType; + } + private void inferOutputFieldsOneSide(final BatchSchema schema) { for (MaterializedField field : schema) { container.addOrGet(field, callBack); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java index 4a32b3a9bdd..a8c5224a234 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java @@ -356,6 +356,11 @@ static RuleSet getDrillUserConfigurableLogicalRules(OptimizerRulesContext optimi Convert from Calcite Logical to Drill Logical Rules. */ RuleInstance.EXPAND_CONVERSION_RULE, + + // Expand GROUPING SETS, ROLLUP, and CUBE BEFORE converting aggregates to Drill logical operators + // This prevents multi-grouping-set aggregates from being converted to DrillAggregateRel + RuleInstance.AGGREGATE_EXPAND_GROUPING_SETS_RULE, + DrillScanRule.INSTANCE, DrillFilterRule.INSTANCE, DrillProjectRule.INSTANCE, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java index baa39dba236..a370c64e76b 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/RuleInstance.java @@ -39,6 +39,7 @@ import org.apache.calcite.rel.rules.SortRemoveRule; import org.apache.calcite.rel.rules.SubQueryRemoveRule; import org.apache.calcite.rel.rules.UnionToDistinctRule; +import org.apache.drill.exec.planner.logical.DrillAggregateExpandGroupingSetsRule; import org.apache.drill.exec.planner.logical.DrillConditions; import org.apache.drill.exec.planner.logical.DrillRelFactories; import com.google.common.base.Preconditions; @@ -107,6 +108,9 @@ public boolean matches(RelOptRuleCall call) { .withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER) .toRule(); + RelOptRule AGGREGATE_EXPAND_GROUPING_SETS_RULE = + DrillAggregateExpandGroupingSetsRule.INSTANCE; + /** * Instance of the rule that works on logical joins only, and pushes to the * right. diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java new file mode 100644 index 00000000000..9756a408824 --- /dev/null +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill.exec.planner.logical; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.plan.RelOptCluster; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.rel.InvalidRelException; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.logical.LogicalAggregate; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL + * of multiple aggregates, each with a single grouping set. + *

+ * This rule converts: + * SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ()) + *

+ * Into: + * SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b + * UNION ALL + * SELECT a, null, SUM(c), 1 AS $g FROM t GROUP BY a + * UNION ALL + * SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY () + *

+ * The $g column is the grouping ID that can be used by GROUPING() and GROUPING_ID() functions. + * Currently, the $g column is generated internally but stripped from the final output. + */ +public class DrillAggregateExpandGroupingSetsRule extends RelOptRule { + + public static final DrillAggregateExpandGroupingSetsRule INSTANCE = + new DrillAggregateExpandGroupingSetsRule(); + public static final String GROUPING_ID_COLUMN_NAME = "$g"; + public static final String GROUP_ID_COLUMN_NAME = "$group_id"; + public static final String EXPRESSION_COLUMN_PLACEHOLDER = "EXPR$"; + + private DrillAggregateExpandGroupingSetsRule() { + super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER, + "DrillAggregateExpandGroupingSetsRule"); + } + + @Override + public boolean matches(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + return aggregate.getGroupSets().size() > 1 + && (aggregate instanceof DrillAggregateRel || aggregate instanceof LogicalAggregate); + } + + @Override + public void onMatch(RelOptRuleCall call) { + final Aggregate aggregate = call.rel(0); + final RelOptCluster cluster = aggregate.getCluster(); + + GroupingFunctionAnalysis analysis = analyzeGroupingFunctions(aggregate.getAggCallList()); + GroupingSetOrderingResult ordering = sortAndAssignGroupIds(aggregate.getGroupSets()); + + List perGroupAggregates = new ArrayList<>(); + for (int i = 0; i < ordering.sortedGroupSets.size(); i++) { + perGroupAggregates.add( + createAggregateForGroupingSet(call, aggregate, ordering.sortedGroupSets.get(i), + ordering.groupIds.get(i), analysis.regularAggCalls)); + } + + RelNode unionResult = buildUnion(cluster, perGroupAggregates); + RelNode result = buildFinalProject(call, unionResult, aggregate, analysis); + + call.transformTo(result); + } + + /** + * Encapsulates analysis results of aggregate calls to determine + * which are regular aggregates and which are grouping-related + * functions (GROUPING, GROUPING_ID, GROUP_ID). + */ + private static class GroupingFunctionAnalysis { + final boolean hasGroupingFunctions; + final List regularAggCalls; + final List groupingFunctionCalls; + final List groupingFunctionPositions; + + GroupingFunctionAnalysis(List regularAggCalls, + List groupingFunctionCalls, + List groupingFunctionPositions) { + this.hasGroupingFunctions = !groupingFunctionPositions.isEmpty(); + this.regularAggCalls = regularAggCalls; + this.groupingFunctionCalls = groupingFunctionCalls; + this.groupingFunctionPositions = groupingFunctionPositions; + } + } + + /** + * Holds the sorted grouping sets (largest first) and their assigned group IDs. + */ + private static class GroupingSetOrderingResult { + final List sortedGroupSets; + final List groupIds; + GroupingSetOrderingResult(List sortedGroupSets, List groupIds) { + this.sortedGroupSets = sortedGroupSets; + this.groupIds = groupIds; + } + } + + /** + * Analyzes aggregate calls to identify which ones are GROUPING-related functions. + * + * @param aggCalls list of aggregate calls in the original aggregate + * @return structure classifying grouping and non-grouping calls + */ + private GroupingFunctionAnalysis analyzeGroupingFunctions(List aggCalls) { + List regularAggCalls = new ArrayList<>(); + List groupingFunctionCalls = new ArrayList<>(); + List groupingFunctionPositions = new ArrayList<>(); + + for (int i = 0; i < aggCalls.size(); i++) { + AggregateCall aggCall = aggCalls.get(i); + SqlKind kind = aggCall.getAggregation().getKind(); + switch (kind) { + case GROUPING: + case GROUPING_ID: + case GROUP_ID: + groupingFunctionPositions.add(i); + groupingFunctionCalls.add(aggCall); + break; + default: + regularAggCalls.add(aggCall); + } + } + + return new GroupingFunctionAnalysis(regularAggCalls, + groupingFunctionCalls, groupingFunctionPositions); + } + + /** + * Sorts the given grouping sets in descending order of their cardinality + * and assigns group IDs to each grouping set based on their occurrences. + * + * @param groupSets a list of grouping sets represented as ImmutableBitSet instances + * @return a GroupingSetOrderingResult containing the sorted grouping sets and their assigned group IDs + */ + private GroupingSetOrderingResult sortAndAssignGroupIds(List groupSets) { + List sortedGroupSets = new ArrayList<>(groupSets); + sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(), a.cardinality())); + + Map groupSetOccurrences = new HashMap<>(); + List groupIds = new ArrayList<>(); + + for (ImmutableBitSet groupSet : sortedGroupSets) { + int groupId = groupSetOccurrences.getOrDefault(groupSet, 0); + groupIds.add(groupId); + groupSetOccurrences.put(groupSet, groupId + 1); + } + + return new GroupingSetOrderingResult(sortedGroupSets, groupIds); + } + + /** + * Creates a new aggregate relational node for a specific grouping set. This method constructs + * the necessary aggregation logic and ensures proper handling of grouping columns, aggregate + * calls, and additional metadata such as grouping and group IDs. + * + * @param call the RelOptRuleCall instance being processed + * @param originalAgg the original aggregate relational expression + * @param groupSet the grouping set to be handled in the new aggregate + * @param groupId the unique identifier associated with this specific grouping set + * @param regularAggCalls the list of regular aggregate calls to be included in the aggregate + * @return a RelNode representing the newly created aggregate for the specified grouping set + */ + private RelNode createAggregateForGroupingSet( + RelOptRuleCall call, + Aggregate originalAgg, + ImmutableBitSet groupSet, + int groupId, + List regularAggCalls) { + + ImmutableBitSet fullGroupSet = originalAgg.getGroupSet(); + RelOptCluster cluster = originalAgg.getCluster(); + RexBuilder rexBuilder = cluster.getRexBuilder(); + RelDataTypeFactory typeFactory = cluster.getTypeFactory(); + RelNode input = originalAgg.getInput(); + + Aggregate newAggregate; + if (originalAgg instanceof DrillAggregateRel) { + newAggregate = new DrillAggregateRel(cluster, originalAgg.getTraitSet(), input, + groupSet, ImmutableList.of(groupSet), regularAggCalls); + } else { + newAggregate = originalAgg.copy(originalAgg.getTraitSet(), input, groupSet, + ImmutableList.of(groupSet), regularAggCalls); + } + + List projects = new ArrayList<>(); + List fieldNames = new ArrayList<>(); + int aggOutputIdx = 0; + int outputColIdx = 0; + + // Populate grouping columns (nulls for omitted columns) + for (int col : fullGroupSet) { + if (groupSet.get(col)) { + projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++)); + } else { + RelDataType nullType = originalAgg.getRowType().getFieldList().get(outputColIdx).getType(); + projects.add(rexBuilder.makeNullLiteral(nullType)); + } + fieldNames.add(originalAgg.getRowType().getFieldList().get(outputColIdx++).getName()); + } + + // Add regular aggregates + for (AggregateCall regCall : regularAggCalls) { + projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++)); + fieldNames.add(regCall.getName() != null ? regCall.getName() : "agg$" + aggOutputIdx); + } + + // Add grouping ID ($g) + int groupingId = computeGroupingId(fullGroupSet, groupSet); + projects.add(rexBuilder.makeLiteral(groupingId, + typeFactory.createSqlType(SqlTypeName.INTEGER), true)); + fieldNames.add(GROUPING_ID_COLUMN_NAME); + + // Add group ID ($group_id) + projects.add(rexBuilder.makeLiteral(groupId, + typeFactory.createSqlType(SqlTypeName.INTEGER), true)); + fieldNames.add(GROUP_ID_COLUMN_NAME); + + return call.builder().push(newAggregate).project(projects, fieldNames, false).build(); + } + + private int computeGroupingId(ImmutableBitSet fullGroupSet, ImmutableBitSet groupSet) { + int id = 0; + int bit = 0; + for (int col : fullGroupSet) { + if (!groupSet.get(col)) { + id |= (1 << bit); + } + bit++; + } + return id; + } + + /** + * Builds a union of the given aggregate relational nodes. If there is only one + * aggregate node, it returns that node directly. Otherwise, it constructs a + * union relational expression containing all the provided aggregate nodes. + * + * @param cluster the optimization cluster in which the relational node resides + * @param aggregates a list of aggregate relational nodes to be combined into a union + * @return the resultant union relational node if multiple nodes are provided; + * otherwise, the single aggregate node from the input list + * @throws RuntimeException if union creation fails due to invalid relational state + */ + private RelNode buildUnion(RelOptCluster cluster, List aggregates) { + if (aggregates.size() == 1) { + return aggregates.get(0); + } + try { + List convertedInputs = new ArrayList<>(); + for (RelNode agg : aggregates) { + convertedInputs.add(convert(agg, agg.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify())); + } + return new DrillUnionRel(cluster, + cluster.traitSet().plus(DrillRel.DRILL_LOGICAL), + convertedInputs, + true, + true, + true); + } catch (InvalidRelException e) { + throw new RuntimeException("Failed to create DrillUnionRel", e); + } + } + + /** + * Builds the final projection for the result of the aggregation, incorporating the necessary + * output columns such as the grouping functions and the aggregation results. + * + * @param call the RelOptRuleCall instance being processed + * @param unionResult the relational expression resulting from the union of partial aggregates + * @param aggregate the original Aggregate relational expression + * @param analysis the analysis results classifying regular and grouping-related aggregate calls + * @return the relational expression with the final projection applied + */ + private RelNode buildFinalProject( + RelOptRuleCall call, + RelNode unionResult, + Aggregate aggregate, + GroupingFunctionAnalysis analysis) { + + RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory(); + ImmutableBitSet fullGroupSet = aggregate.getGroupSet(); + List finalProjects = new ArrayList<>(); + List finalFieldNames = new ArrayList<>(); + int numFields = unionResult.getRowType().getFieldCount(); + + for (int i = 0; i < fullGroupSet.cardinality(); i++) { + finalProjects.add(rexBuilder.makeInputRef(unionResult, i)); + finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName()); + } + + if (analysis.hasGroupingFunctions) { + RexNode gColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 2); + RexNode groupIdColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 1); + Map groupingFuncMap = new HashMap<>(); + for (int i = 0; i < analysis.groupingFunctionPositions.size(); i++) { + groupingFuncMap.put(analysis.groupingFunctionPositions.get(i), + analysis.groupingFunctionCalls.get(i)); + } + + int regularAggIndex = fullGroupSet.cardinality(); + for (int origPos = 0; origPos < aggregate.getAggCallList().size(); origPos++) { + if (groupingFuncMap.containsKey(origPos)) { + AggregateCall groupingCall = groupingFuncMap.get(origPos); + String funcName = groupingCall.getAggregation().getName(); + if ("GROUPING".equals(funcName)) { + processGrouping(groupingCall, fullGroupSet, rexBuilder, typeFactory, + gColumnRef, finalProjects, finalFieldNames); + } else if ("GROUPING_ID".equals(funcName)) { + processGroupingId(groupingCall, fullGroupSet, rexBuilder, typeFactory, + gColumnRef, finalProjects, finalFieldNames); + } else if ("GROUP_ID".equals(funcName)) { + finalProjects.add(groupIdColumnRef); + String fieldName = groupingCall.getName() != null + ? groupingCall.getName() + : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size(); + finalFieldNames.add(fieldName); + } + } else { + finalProjects.add(rexBuilder.makeInputRef(unionResult, regularAggIndex)); + finalFieldNames.add(unionResult.getRowType().getFieldList().get(regularAggIndex).getName()); + regularAggIndex++; + } + } + } else { + for (int i = fullGroupSet.cardinality(); i < numFields - 2; i++) { + finalProjects.add(rexBuilder.makeInputRef(unionResult, i)); + finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName()); + } + } + + return call.builder().push(unionResult).project(finalProjects, finalFieldNames, false).build(); + } + + /** + * Processes the GROUPING aggregate function by extracting the bit representing + * whether each column is aggregated or not and appends the computed RexNode + * projection and the corresponding field name to the provided lists. + * + * @param groupingCall the GROUPING aggregate function call to process + * @param fullGroupSet the complete set of grouping keys for the aggregation + * @param rexBuilder the RexBuilder instance used to construct RexNode expressions + * @param typeFactory the data type factory used for creating type-specific literals + * @param gColumnRef the RexNode reference to the grouping column + * @param finalProjects the list to store the constructed RexNode projections + * @param finalFieldNames the list to store the corresponding field names + */ + private void processGrouping(AggregateCall groupingCall, + ImmutableBitSet fullGroupSet, + RexBuilder rexBuilder, + RelDataTypeFactory typeFactory, + RexNode gColumnRef, + List finalProjects, + List finalFieldNames) { + + if (groupingCall.getArgList().size() != 1) { + throw new RuntimeException("GROUPING() expects exactly 1 argument"); + } + + int columnIndex = groupingCall.getArgList().get(0); + int bitPosition = 0; + for (int col : fullGroupSet) { + if (col == columnIndex) { + break; + } + bitPosition++; + } + + RexNode divisor = rexBuilder.makeLiteral( + 1 << bitPosition, typeFactory.createSqlType(SqlTypeName.INTEGER), true); + + RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, gColumnRef, divisor); + RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, divided, + rexBuilder.makeLiteral(2, typeFactory.createSqlType(SqlTypeName.INTEGER), true)); + + finalProjects.add(extractBit); + String fieldName = groupingCall.getName() != null + ? groupingCall.getName() + : "EXPR$" + finalFieldNames.size(); + finalFieldNames.add(fieldName); + } + + /** + * Processes the GROUPING_ID aggregate function by computing a bitmask + * based on the provided grouping columns and full group set. Constructs + * the corresponding RexNode representation for the GROUPING_ID function + * and appends it to the final projection and field names list. + * + * @param groupingCall the GROUPING_ID aggregate function call to process + * @param fullGroupSet the complete set of grouping keys for the aggregation + * @param rexBuilder the RexBuilder instance used to construct RexNode expressions + * @param typeFactory the data type factory for creating type-specific literals + * @param gColumnRef the RexNode reference to the group column + * @param finalProjects the list to which the computed RexNode is added + * @param finalFieldNames the list to which the corresponding field name is added + */ + private void processGroupingId(AggregateCall groupingCall, + ImmutableBitSet fullGroupSet, + RexBuilder rexBuilder, + RelDataTypeFactory typeFactory, + RexNode gColumnRef, + List finalProjects, + List finalFieldNames) { + + if (groupingCall.getArgList().isEmpty()) { + throw new RuntimeException("GROUPING_ID() expects at least one argument"); + } + + RexNode result = null; + for (int i = 0; i < groupingCall.getArgList().size(); i++) { + int columnIndex = groupingCall.getArgList().get(i); + int bitPosition = 0; + for (int col : fullGroupSet) { + if (col == columnIndex) { + break; + } + bitPosition++; + } + + RexNode divisor = rexBuilder.makeLiteral(1 << bitPosition, + typeFactory.createSqlType(SqlTypeName.INTEGER), true); + + RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, gColumnRef, divisor); + RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, divided, + rexBuilder.makeLiteral(2, typeFactory.createSqlType(SqlTypeName.INTEGER), true)); + + int resultBitPos = groupingCall.getArgList().size() - 1 - i; + RexNode bitInPosition = (resultBitPos > 0) + ? rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, extractBit, + rexBuilder.makeLiteral(1 << resultBitPos, + typeFactory.createSqlType(SqlTypeName.INTEGER), true)) + : extractBit; + + result = (result == null) + ? bitInPosition + : rexBuilder.makeCall(SqlStdOperatorTable.PLUS, result, bitInPosition); + } + + finalProjects.add(result); + String fieldName = groupingCall.getName() != null + ? groupingCall.getName() + : "EXPR$" + finalFieldNames.size(); + finalFieldNames.add(fieldName); + } +} diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java index 8ce07dd229a..86d67a866ce 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateRule.java @@ -49,6 +49,12 @@ public void onMatch(RelOptRuleCall call) { return; } + if (aggregate.getGroupSets().size() > 1) { + // Don't convert aggregates with multiple grouping sets (GROUPING SETS/ROLLUP/CUBE) to DrillAggregateRel + // These should be expanded into UNION ALL by DrillAggregateExpandGroupingSetsRule first + return; + } + final RelTraitSet traits = aggregate.getTraitSet().plus(DrillRel.DRILL_LOGICAL); final RelNode convertedInput = convert(input, input.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()); call.transformTo(new DrillAggregateRel(aggregate.getCluster(), traits, convertedInput, diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java index 9cf1b26d45f..349ba2a02f9 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillOptiq.java @@ -917,6 +917,11 @@ public LogicalExpression visitLiteral(RexLiteral literal) { return (ValueExpressions.getIntervalDay(((BigDecimal) (literal.getValue())).longValue())); case NULL: return NullExpression.INSTANCE; + case UNKNOWN: + // UNKNOWN type is used for NULL literals where the type should be inferred later + // This is used by GROUPING SETS expansion where NULL placeholders need type inference + // from the other branch of UNION ALL + return NullExpression.INSTANCE; case ANY: if (isLiteralNull(literal)) { return NullExpression.INSTANCE; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java index a35d320bc28..18719b54479 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionAllRule.java @@ -57,9 +57,25 @@ public void onMatch(RelOptRuleCall call) { final RelNode convertedInput = convert(input, input.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()); convertedInputs.add(convertedInput); } + + // Detect if this union is from GROUPING SETS expansion by checking if ANY input has a $g column + // The $g column is the grouping ID that we add during expansion + // Check all inputs because the union tree may be built incrementally (binary tree structure) + boolean isGroupingSetsExpansion = false; + for (RelNode input : convertedInputs) { + org.apache.calcite.rel.type.RelDataType inputType = input.getRowType(); + if (inputType.getFieldCount() > 0) { + String lastFieldName = inputType.getFieldList().get(inputType.getFieldCount() - 1).getName(); + if ("$g".equals(lastFieldName)) { + isGroupingSetsExpansion = true; + break; + } + } + } + try { call.transformTo(new DrillUnionRel(union.getCluster(), traits, convertedInputs, union.all, - true /* check compatibility */)); + true /* check compatibility */, isGroupingSetsExpansion)); } catch (InvalidRelException e) { tracer.warn(e.toString()); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java index 263266e1ffa..7d01ec110e7 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillUnionRel.java @@ -37,21 +37,33 @@ * Union implemented in Drill. */ public class DrillUnionRel extends Union implements DrillRel, DrillSetOpRel { + private final boolean isGroupingSetsExpansion; + /** Creates a DrillUnionRel. */ public DrillUnionRel(RelOptCluster cluster, RelTraitSet traits, List inputs, boolean all, boolean checkCompatibility) throws InvalidRelException { + this(cluster, traits, inputs, all, checkCompatibility, false); + } + + public DrillUnionRel(RelOptCluster cluster, RelTraitSet traits, + List inputs, boolean all, boolean checkCompatibility, boolean isGroupingSetsExpansion) throws InvalidRelException { super(cluster, traits, inputs, all); + this.isGroupingSetsExpansion = isGroupingSetsExpansion; if (checkCompatibility && !this.isCompatible(getRowType(), getInputs())) { throw new InvalidRelException("Input row types of the Union are not compatible."); } } + public boolean isGroupingSetsExpansion() { + return isGroupingSetsExpansion; + } + @Override public DrillUnionRel copy(RelTraitSet traitSet, List inputs, boolean all) { try { return new DrillUnionRel(getCluster(), traitSet, inputs, all, - false /* don't check compatibility during copy */); + false /* don't check compatibility during copy */, isGroupingSetsExpansion); } catch (InvalidRelException e) { throw new AssertionError(e); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java index a61db99510c..1b3805e5bd1 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/HashAggPrule.java @@ -64,6 +64,12 @@ public void onMatch(RelOptRuleCall call) { return; } + if (aggregate.getGroupSets().size() > 1) { + // Don't use HashAggregate for aggregates with multiple grouping sets (GROUPING SETS/ROLLUP/CUBE) + // These should be expanded into UNION ALL by DrillAggregateExpandGroupingSetsRule first + return; + } + RelTraitSet traits; try { diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java index cbed109d19b..49a4e935399 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java @@ -64,6 +64,12 @@ public void onMatch(RelOptRuleCall call) { return; } + if (aggregate.getGroupSets().size() > 1) { + // Don't use StreamingAggregate for aggregates with multiple grouping sets (GROUPING SETS/ROLLUP/CUBE) + // These should be expanded into UNION ALL by DrillAggregateExpandGroupingSetsRule first + return; + } + try { if (aggregate.getGroupSet().isEmpty()) { DrillDistributionTrait singleDist = DrillDistributionTrait.SINGLETON; diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java index 460346fa118..a60c7fc2768 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrel.java @@ -39,17 +39,24 @@ public class UnionAllPrel extends UnionPrel { + private final boolean isGroupingSetsExpansion; + public UnionAllPrel(RelOptCluster cluster, RelTraitSet traits, List inputs) throws InvalidRelException { - super(cluster, traits, inputs, true /* all */); + this(cluster, traits, inputs, false); + } + public UnionAllPrel(RelOptCluster cluster, RelTraitSet traits, List inputs, boolean isGroupingSetsExpansion) + throws InvalidRelException { + super(cluster, traits, inputs, true /* all */); + this.isGroupingSetsExpansion = isGroupingSetsExpansion; } @Override public Union copy(RelTraitSet traitSet, List inputs, boolean all) { try { - return new UnionAllPrel(this.getCluster(), traitSet, inputs); + return new UnionAllPrel(this.getCluster(), traitSet, inputs, isGroupingSetsExpansion); }catch (InvalidRelException e) { throw new AssertionError(e); } @@ -78,7 +85,7 @@ public PhysicalOperator getPhysicalOperator(PhysicalPlanCreator creator) throws inputPops.add( ((Prel)this.getInputs().get(i)).getPhysicalOperator(creator)); } - UnionAll unionall = new UnionAll(inputPops); + UnionAll unionall = new UnionAll(inputPops, isGroupingSetsExpansion); return creator.addMetadata(this, unionall); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java index 8e094604cf1..40ebe7f1cd4 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/UnionAllPrule.java @@ -101,7 +101,8 @@ public void onMatch(RelOptRuleCall call) { Preconditions.checkArgument(convertedInputList.size() >= 2, "Union list must be at least two items."); RelNode left = convertedInputList.get(0); for (int i = 1; i < convertedInputList.size(); i++) { - left = new UnionAllPrel(union.getCluster(), traits, ImmutableList.of(left, convertedInputList.get(i))); + left = new UnionAllPrel(union.getCluster(), traits, ImmutableList.of(left, convertedInputList.get(i)), + union.isGroupingSetsExpansion()); } call.transformTo(left); diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java index f433308ac24..680e3ca3910 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/sql/parser/UnsupportedOperatorsVisitor.java @@ -17,10 +17,21 @@ */ package org.apache.drill.exec.planner.sql.parser; +import com.google.common.collect.Lists; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.SqlNumericLiteral; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlSelectKeyword; +import org.apache.calcite.sql.SqlWindow; +import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.util.SqlBasicVisitor; +import org.apache.calcite.sql.util.SqlShuttle; import org.apache.calcite.util.Litmus; import org.apache.drill.exec.ExecConstants; import org.apache.drill.exec.exception.UnsupportedOperatorCollector; @@ -28,23 +39,8 @@ import org.apache.drill.exec.planner.physical.PlannerSettings; import org.apache.drill.exec.work.foreman.SqlUnsupportedException; -import org.apache.calcite.sql.SqlSelectKeyword; -import org.apache.calcite.sql.SqlIdentifier; -import org.apache.calcite.sql.SqlSelect; -import org.apache.calcite.sql.SqlWindow; -import org.apache.calcite.sql.fun.SqlCountAggFunction; -import org.apache.calcite.sql.SqlCall; -import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.sql.SqlJoin; -import org.apache.calcite.sql.SqlNode; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.sql.util.SqlShuttle; -import org.apache.calcite.sql.SqlDataTypeSpec; - import java.util.List; -import com.google.common.collect.Lists; - public class UnsupportedOperatorsVisitor extends SqlShuttle { private QueryContext context; private static List disabledType = Lists.newArrayList(); @@ -97,10 +93,6 @@ public SqlNode visit(SqlCall sqlCall) { if (sqlCall instanceof SqlSelect) { SqlSelect sqlSelect = (SqlSelect) sqlCall; - checkGrouping((sqlSelect)); - - checkRollupCubeGrpSets(sqlSelect); - for (SqlNode nodeInSelectList : sqlSelect.getSelectList()) { // If the window function is used with an alias, // enter the first operand of AS operator @@ -358,27 +350,6 @@ public SqlNode visit(SqlCall sqlCall) { return sqlCall.getOperator().acceptCall(this, sqlCall); } - private void checkRollupCubeGrpSets(SqlSelect sqlSelect) { - final ExprFinder rollupCubeGrpSetsFinder = new ExprFinder(RollupCubeGrpSets); - sqlSelect.accept(rollupCubeGrpSetsFinder); - if (rollupCubeGrpSetsFinder.find()) { - unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION, - "Rollup, Cube, Grouping Sets are not supported in GROUP BY clause.\n" + - "See Apache Drill JIRA: DRILL-3962"); - throw new UnsupportedOperationException(); - } - } - - private void checkGrouping(SqlSelect sqlSelect) { - final ExprFinder groupingFinder = new ExprFinder(GroupingID); - sqlSelect.accept(groupingFinder); - if (groupingFinder.find()) { - unsupportedOperatorCollector.setException(SqlUnsupportedException.ExceptionType.FUNCTION, - "Grouping, Grouping_ID, Group_ID are not supported.\n" + - "See Apache Drill JIRA: DRILL-3962"); - throw new UnsupportedOperationException(); - } - } private boolean checkDirExplorers(SqlNode sqlNode) { final ExprFinder dirExplorersFinder = new ExprFinder(DirExplorersCondition); @@ -401,41 +372,6 @@ private interface SqlNodeCondition { boolean test(SqlNode sqlNode); } - /** - * A condition that returns true if SqlNode has rollup, cube, grouping_sets. - * */ - private final SqlNodeCondition RollupCubeGrpSets = new SqlNodeCondition() { - @Override - public boolean test(SqlNode sqlNode) { - if (sqlNode instanceof SqlCall) { - final SqlOperator operator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); - if (operator == SqlStdOperatorTable.ROLLUP - || operator == SqlStdOperatorTable.CUBE - || operator == SqlStdOperatorTable.GROUPING_SETS) { - return true; - } - } - return false; - } - }; - - /** - * A condition that returns true if SqlNode has Grouping, Grouping_ID, GROUP_ID. - */ - private final SqlNodeCondition GroupingID = new SqlNodeCondition() { - @Override - public boolean test(SqlNode sqlNode) { - if (sqlNode instanceof SqlCall) { - final SqlOperator operator = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(((SqlCall) sqlNode).getOperator()); - if (operator == SqlStdOperatorTable.GROUPING - || operator == SqlStdOperatorTable.GROUPING_ID - || operator == SqlStdOperatorTable.GROUP_ID) { - return true; - } - } - return false; - } - }; /** * A condition that returns true if SqlNode has Directory Explorers. diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java index b1649fc5350..ce7a79c1319 100644 --- a/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java +++ b/exec/java-exec/src/test/java/org/apache/drill/TestDisabledFunctionality.java @@ -245,75 +245,4 @@ public void testDisableDecimalFromParquet() throws Exception { resetSessionOption(PlannerSettings.ENABLE_DECIMAL_DATA_TYPE_KEY); } } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableRollup() throws Exception{ - try { - test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet` group by rollup(n_regionkey, n_name)"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableCube() throws Exception{ - try { - test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet` group by cube(n_regionkey, n_name)"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableGroupingSets() throws Exception{ - try { - test("select n_regionkey, count(*) as cnt from cp.`tpch/nation.parquet` group by grouping sets(n_regionkey, n_name)"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableGrouping() throws Exception{ - try { - test("select n_regionkey, count(*), GROUPING(n_regionkey) from cp.`tpch/nation.parquet` group by n_regionkey;"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableGrouping_ID() throws Exception{ - try { - test("select n_regionkey, count(*), GROUPING_ID(n_regionkey) from cp.`tpch/nation.parquet` group by n_regionkey;"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableGroup_ID() throws Exception{ - try { - test("select n_regionkey, count(*), GROUP_ID() from cp.`tpch/nation.parquet` group by n_regionkey;"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - - @Test (expected = UnsupportedFunctionException.class) //DRILL-3802 - public void testDisableGroupingInFilter() throws Exception{ - try { - test("select n_regionkey, count(*) from cp.`tpch/nation.parquet` group by n_regionkey HAVING GROUPING(n_regionkey) = 1"); - } catch(UserException ex) { - throwAsUnsupportedException(ex); - throw ex; - } - } - } diff --git a/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java b/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java new file mode 100644 index 00000000000..2c05e212590 --- /dev/null +++ b/exec/java-exec/src/test/java/org/apache/drill/TestGroupingSetsResults.java @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.drill; + +import org.apache.drill.categories.OperatorTest; +import org.apache.drill.categories.SqlTest; +import org.apache.drill.test.ClusterFixture; +import org.apache.drill.test.ClusterFixtureBuilder; +import org.apache.drill.test.ClusterTest; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +@Category({SqlTest.class, OperatorTest.class}) +public class TestGroupingSetsResults extends ClusterTest { + + @BeforeClass + public static void setUp() throws Exception { + ClusterFixtureBuilder builder = ClusterFixture.builder(dirTestWatcher); + startCluster(builder); + } + + @Test + public void testSimpleGroupingSetsResults() throws Exception { + String query = "select n_regionkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "group by grouping sets ((n_regionkey), ())"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "cnt") + .baselineValues(0, 5L) + .baselineValues(1, 5L) + .baselineValues(2, 5L) + .baselineValues(3, 5L) + .baselineValues(4, 5L) + .baselineValues(null, 25L) // Grand total + .go(); + } + + @Test + public void testRollupResults() throws Exception { + // ROLLUP(a, b) creates grouping sets: (a, b), (a), () + String query = "select n_regionkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by rollup(n_regionkey)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "cnt") + .baselineValues(0, 5L) // Region 0 + .baselineValues(1, 5L) // Region 1 + .baselineValues(null, 10L) // Grand total + .go(); + } + + @Test + public void testCubeResults() throws Exception { + // CUBE(a) creates grouping sets: (a), () + String query = "select n_regionkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by cube(n_regionkey)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "cnt") + .baselineValues(0, 5L) // Region 0 + .baselineValues(1, 5L) // Region 1 + .baselineValues(null, 10L) // Grand total + .go(); + } + + @Test + public void testMultiColumnGroupingSets() throws Exception { + // Test GROUPING SETS with two columns + String query = "select n_regionkey, n_nationkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey = 0 and n_nationkey in (0, 5) " + + "group by grouping sets ((n_regionkey, n_nationkey), (n_regionkey), ())"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "n_nationkey", "cnt") + // Grouping set (n_regionkey, n_nationkey) + .baselineValues(0, 0, 1L) // Region 0, nation 0 + .baselineValues(0, 5, 1L) // Region 0, nation 5 + // Grouping set (n_regionkey) + .baselineValues(0, null, 2L) // Region 0 total + // Grouping set () + .baselineValues(null, null, 2L) // Grand total + .go(); + } + + @Test + public void testRollupTwoColumns() throws Exception { + // ROLLUP(a, b) creates grouping sets: (a, b), (a), () + String query = "select n_regionkey, n_nationkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey = 0 and n_nationkey in (0, 5) " + + "group by rollup(n_regionkey, n_nationkey)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "n_nationkey", "cnt") + // Grouping set (n_regionkey, n_nationkey) + .baselineValues(0, 0, 1L) // Region 0, nation 0 + .baselineValues(0, 5, 1L) // Region 0, nation 5 + // Grouping set (n_regionkey) + .baselineValues(0, null, 2L) // Region 0 subtotal + // Grouping set () + .baselineValues(null, null, 2L) // Grand total + .go(); + } + + @Test + public void testCubeTwoColumns() throws Exception { + // CUBE(a, b) creates grouping sets: (a, b), (a), (b), () + // Using specific nations to make the test deterministic + String query = "select n_regionkey, n_nationkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where (n_regionkey = 0 and n_nationkey in (0, 5)) " + + " or (n_regionkey = 1 and n_nationkey in (1, 2)) " + + "group by cube(n_regionkey, n_nationkey)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "n_nationkey", "cnt") + // Grouping set (n_regionkey, n_nationkey) + .baselineValues(0, 0, 1L) // Region 0, nation 0 + .baselineValues(0, 5, 1L) // Region 0, nation 5 + .baselineValues(1, 1, 1L) // Region 1, nation 1 + .baselineValues(1, 2, 1L) // Region 1, nation 2 + // Grouping set (n_regionkey) + .baselineValues(0, null, 2L) // Region 0 total + .baselineValues(1, null, 2L) // Region 1 total + // Grouping set (n_nationkey) + .baselineValues(null, 0, 1L) // Nation 0 across all regions + .baselineValues(null, 1, 1L) // Nation 1 across all regions + .baselineValues(null, 2, 1L) // Nation 2 across all regions + .baselineValues(null, 5, 1L) // Nation 5 across all regions + // Grouping set () + .baselineValues(null, null, 4L) // Grand total + .go(); + } + + @Test + public void testGroupingSetsWithAggregates() throws Exception { + // Test multiple aggregate functions with GROUPING SETS + String query = "select n_regionkey, " + + "count(*) as cnt, " + + "min(n_nationkey) as min_key, " + + "max(n_nationkey) as max_key " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by grouping sets ((n_regionkey), ())"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "cnt", "min_key", "max_key") + .baselineValues(0, 5L, 0, 16) // Region 0 + .baselineValues(1, 5L, 1, 24) // Region 1 + .baselineValues(null, 10L, 0, 24) // Grand total + .go(); + } + + @Test + public void testGroupingSetsEmptyGroupingSet() throws Exception { + // Test just the empty grouping set (grand total only) + String query = "select count(*) as cnt, sum(n_nationkey) as sum_key " + + "from cp.`tpch/nation.parquet` " + + "group by grouping sets (())"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("cnt", "sum_key") + .baselineValues(25L, 300L) // Grand total: 25 nations, sum 0+1+2+...+24 = 300 + .go(); + } + + @Test + public void testGroupingSetsWithWhere() throws Exception { + // Test GROUPING SETS with WHERE clause + String query = "select n_regionkey, count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey in (0, 1, 2) " + + "group by grouping sets ((n_regionkey), ())"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "cnt") + .baselineValues(0, 5L) + .baselineValues(1, 5L) + .baselineValues(2, 5L) + .baselineValues(null, 15L) // Total of regions 0, 1, 2 + .go(); + } + + @Test + public void testGroupingSetsWithExpression() throws Exception { + // Test GROUPING SETS with computed columns + String query = "select n_regionkey, " + + "case when n_nationkey < 10 then 'low' else 'high' end as key_range, " + + "count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by grouping sets (" + + " (n_regionkey, case when n_nationkey < 10 then 'low' else 'high' end), " + + " (n_regionkey)" + + ")"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("n_regionkey", "key_range", "cnt") + // Grouping set (n_regionkey, key_range) + .baselineValues(0, "low", 2L) // Region 0, low keys (0,5) + .baselineValues(0, "high", 3L) // Region 0, high keys (14,15,16) + .baselineValues(1, "low", 3L) // Region 1, low keys (1,2,3) + .baselineValues(1, "high", 2L) // Region 1, high keys (17,24) + // Grouping set (n_regionkey) + .baselineValues(0, null, 5L) // Region 0 total + .baselineValues(1, null, 5L) // Region 1 total + .go(); + } + + @Test + public void testRollupWithJSON() throws Exception { + // Test ROLLUP with JSON data + String query = "select education_level, count(*) as cnt " + + "from cp.`employee.json` " + + "where education_level in ('Graduate Degree', 'Bachelors Degree', 'Partial College') " + + "group by rollup(education_level)"; + + // This should now work with proper type handling + queryBuilder() + .sql(query) + .run(); + } + + // Tests for GROUPING() and GROUPING_ID() functions + // These functions help distinguish between NULL values that are actual data + // versus NULL values inserted by GROUPING SETS/ROLLUP/CUBE operations. + + @Test + public void testGroupingFunction() throws Exception { + // Test GROUPING function with ROLLUP + // GROUPING returns 1 if the column is aggregated (NULL in output), 0 otherwise + String query = "select education_level, " + + "GROUPING(education_level) as grp, " + + "count(*) as cnt " + + "from cp.`employee.json` " + + "where education_level in ('Graduate Degree', 'Bachelors Degree') " + + "group by rollup(education_level)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("education_level", "grp", "cnt") + .baselineValues("Graduate Degree", 0, 170L) // Not aggregated: grp=0 + .baselineValues("Bachelors Degree", 0, 287L) // Not aggregated: grp=0 + .baselineValues(null, 1, 457L) // Aggregated (grand total): grp=1 + .go(); + } + + @Test + public void testGroupingIdFunction() throws Exception { + // Test GROUPING_ID function with CUBE + // GROUPING_ID returns a bitmap where bit i is 1 if column i is aggregated + // For CUBE(marital_status, education_level), we get grouping sets: + // (marital_status, education_level), (marital_status), (education_level), () + String query = "select marital_status, education_level, " + + "GROUPING_ID(marital_status, education_level) as grp_id, " + + "count(*) as cnt " + + "from cp.`employee.json` " + + "where marital_status in ('S', 'M') " + + "and education_level in ('Graduate Degree', 'Bachelors Degree') " + + "group by cube(marital_status, education_level)"; + + testBuilder() + .sqlQuery(query) + .unOrdered() + .baselineColumns("marital_status", "education_level", "grp_id", "cnt") + // (marital_status, education_level) - neither aggregated: grp_id = 0 + .baselineValues("S", "Graduate Degree", 0, 85L) + .baselineValues("S", "Bachelors Degree", 0, 143L) + .baselineValues("M", "Graduate Degree", 0, 85L) + .baselineValues("M", "Bachelors Degree", 0, 144L) + // (marital_status) - education_level aggregated: grp_id = 1 (bit 0 set) + .baselineValues("S", null, 1, 228L) + .baselineValues("M", null, 1, 229L) + // (education_level) - marital_status aggregated: grp_id = 2 (bit 1 set) + .baselineValues(null, "Graduate Degree", 2, 170L) + .baselineValues(null, "Bachelors Degree", 2, 287L) + // () - both aggregated: grp_id = 3 (both bits set) + .baselineValues(null, null, 3, 457L) + .go(); + } + + @Test + public void testGroupIdFunction() throws Exception { + // Test GROUP_ID function with duplicate grouping sets + // GROUP_ID() returns 0 for first occurrence, 1 for second, etc. + String query = "select n_regionkey, " + + "GROUP_ID() as grp_id, " + + "count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by grouping sets ((n_regionkey), (n_regionkey), ()) " + + "order by grp_id, n_regionkey nulls last"; + + testBuilder() + .sqlQuery(query) + .ordered() + .baselineColumns("n_regionkey", "grp_id", "cnt") + // First occurrence of (n_regionkey): grp_id = 0 + .baselineValues(0, 0L, 5L) // Region 0 + .baselineValues(1, 0L, 5L) // Region 1 + .baselineValues(null, 0L, 10L) // Empty grouping set + // Second occurrence of (n_regionkey): grp_id = 1 + .baselineValues(0, 1L, 5L) // Region 0 + .baselineValues(1, 1L, 5L) // Region 1 + .go(); + } + + @Test + public void testGroupIdNoDuplicates() throws Exception { + // Test GROUP_ID when there are no duplicate grouping sets + // All GROUP_ID values should be 0 + String query = "select n_regionkey, " + + "GROUP_ID() as grp_id, " + + "count(*) as cnt " + + "from cp.`tpch/nation.parquet` " + + "where n_regionkey < 2 " + + "group by grouping sets ((n_regionkey), ()) " + + "order by n_regionkey nulls last"; + + testBuilder() + .sqlQuery(query) + .ordered() + .baselineColumns("n_regionkey", "grp_id", "cnt") + .baselineValues(0, 0L, 5L) + .baselineValues(1, 0L, 5L) + .baselineValues(null, 0L, 10L) + .go(); + } +}