Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,47 @@ GROUP BY department HAVING MAX(age) > 30
""");
}

@Test
public void testCountStarWithFilter() {
givenQuery("SELECT COUNT(*) FILTER(WHERE age > 30) FROM catalog.employees")
.assertPlan(
"""
LogicalAggregate(group=[{}], COUNT(*) FILTER(WHERE age > 30)=[COUNT() FILTER $0])
LogicalProject($f1=[>($2, 30)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testFilteredAggregateWithGroupBy() {
givenQuery(
"""
SELECT department, SUM(age) FILTER(WHERE age > 30) FROM catalog.employees
GROUP BY department
""")
.assertPlan(
"""
LogicalAggregate(group=[{0}], SUM(age) FILTER(WHERE age > 30)=[SUM($1) FILTER $2])
LogicalProject(department=[$3], age=[$2], $f3=[>($2, 30)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testMultipleFilteredAggregates() {
givenQuery(
"""
SELECT MAX(age) FILTER(WHERE age > 30), MIN(age) FILTER(WHERE age < 50)
FROM catalog.employees
""")
.assertPlan(
"""
LogicalAggregate(group=[{}], MAX(age) FILTER(WHERE age > 30)=[MAX($0) FILTER $1], MIN(age) FILTER(WHERE age < 50)=[MIN($0) FILTER $2])
LogicalProject(age=[$2], $f4=[>($2, 30)], $f5=[<($2, 50)])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testScalarFnOverAggregate() {
givenQuery("SELECT ABS(MAX(age)) FROM catalog.employees")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand Down Expand Up @@ -46,10 +47,17 @@ public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext
}
return BuiltinFunctionName.ofAggregation(node.getFuncName())
.map(
functionName -> {
return PlanUtils.makeAggCall(
context, functionName, node.getDistinct(), field, argList);
})
functionName ->
PlanUtils.makeAggCall(context, functionName, node.getDistinct(), field, argList))
// Apply the optional FILTER(WHERE ...) predicate; IS TRUE treats NULL as non-matching.
.map(
aggCall ->
node.condition() == null
? aggCall
: aggCall.filter(
context.rexBuilder.makeCall(
SqlStdOperatorTable.IS_TRUE,
rexNodeVisitor.analyze(node.condition(), context))))
.orElseThrow(
() ->
new UnsupportedOperationException("Unexpected aggregation: " + node.getFuncName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,7 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
List<RexInputRef> aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList);
boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs);
trimmedRefs.addAll(aggCallRefs);
trimmedRefs.addAll(getAggCallFilterRefs(aggExprList, context));
context.relBuilder.project(trimmedRefs);

// Re-resolve all attributes based on adding trimmed Project.
Expand Down Expand Up @@ -1795,6 +1796,21 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e
return null;
}

/**
* Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them.
*/
private List<RexInputRef> getAggCallFilterRefs(
List<UnresolvedExpression> aggExprList, CalcitePlanContext context) {
List<RexInputRef> refs = new ArrayList<>();
for (UnresolvedExpression aggExpr : aggExprList) {
AggregateFunction aggFunc = extractAggregateFunction(aggExpr);
if (aggFunc != null && aggFunc.condition() != null) {
refs.addAll(PlanUtils.getInputRefs(rexVisitor.analyze(aggFunc.condition(), context)));
}
}
return refs;
}

private Optional<UnresolvedExpression> getTimeSpanField(UnresolvedExpression expr) {
if (Objects.isNull(expr)) return Optional.empty();
if (expr instanceof Span span && SpanUnit.isTimeUnit(span.getUnit())) {
Expand Down
Loading