From 874dff82ed8dd3e8922c763a3c5437b72dfdc9c1 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 5 Jun 2026 11:56:02 -0700 Subject: [PATCH] fix(calcite): Preserve FILTER(WHERE) on SQL aggregates CalciteAggCallVisitor ignored AggregateFunction.condition, so COUNT(*) FILTER(WHERE age > 30) degraded to COUNT(). Apply the predicate via AggCall.filter() (wrapped in IS TRUE), and retain its input columns in the pre-aggregation trimming projection so they re-resolve. PPL is unaffected (it never sets condition). Signed-off-by: Chen Dai --- .../sql/api/UnifiedQueryPlannerSqlV2Test.java | 41 +++++++++++++++++++ .../sql/calcite/CalciteAggCallVisitor.java | 16 ++++++-- .../sql/calcite/CalciteRelNodeVisitor.java | 16 ++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java index fb38f184925..c2e714a0cb4 100644 --- a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java +++ b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java @@ -346,6 +346,47 @@ GROUP BY department HAVING MAX(age) > 30 """); } + @Test + public void testCountStarWithFilter() { + givenQuery("SELECT COUNT(*) FILTER(WHERE age > 30) FROM catalog.employees") + .assertPlan( + """ + LogicalAggregate(group=[{}], COUNT(*) FILTER(WHERE age > 30)=[COUNT() FILTER $0]) + LogicalProject($f1=[>($2, 30)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testFilteredAggregateWithGroupBy() { + givenQuery( + """ + SELECT department, SUM(age) FILTER(WHERE age > 30) FROM catalog.employees + GROUP BY department + """) + .assertPlan( + """ + LogicalAggregate(group=[{0}], SUM(age) FILTER(WHERE age > 30)=[SUM($1) FILTER $2]) + LogicalProject(department=[$3], age=[$2], $f3=[>($2, 30)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testMultipleFilteredAggregates() { + givenQuery( + """ + SELECT MAX(age) FILTER(WHERE age > 30), MIN(age) FILTER(WHERE age < 50) + FROM catalog.employees + """) + .assertPlan( + """ + LogicalAggregate(group=[{}], MAX(age) FILTER(WHERE age > 30)=[MAX($0) FILTER $1], MIN(age) FILTER(WHERE age < 50)=[MIN($0) FILTER $2]) + LogicalProject(age=[$2], $f4=[>($2, 30)], $f5=[<($2, 50)]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + @Test public void testScalarFnOverAggregate() { givenQuery("SELECT ABS(MAX(age)) FROM catalog.employees") diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java index 0512316628c..e7a5a5a68bd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.tools.RelBuilder.AggCall; import org.apache.logging.log4j.util.Strings; import org.opensearch.sql.ast.AbstractNodeVisitor; @@ -46,10 +47,17 @@ public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext } return BuiltinFunctionName.ofAggregation(node.getFuncName()) .map( - functionName -> { - return PlanUtils.makeAggCall( - context, functionName, node.getDistinct(), field, argList); - }) + functionName -> + PlanUtils.makeAggCall(context, functionName, node.getDistinct(), field, argList)) + // Apply the optional FILTER(WHERE ...) predicate; IS TRUE treats NULL as non-matching. + .map( + aggCall -> + node.condition() == null + ? aggCall + : aggCall.filter( + context.rexBuilder.makeCall( + SqlStdOperatorTable.IS_TRUE, + rexNodeVisitor.analyze(node.condition(), context)))) .orElseThrow( () -> new UnsupportedOperationException("Unexpected aggregation: " + node.getFuncName())); diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index 91a30361a20..ff2e9fbd4bd 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -1563,6 +1563,7 @@ private Pair, List> aggregateWithTrimming( List aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList); boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs); trimmedRefs.addAll(aggCallRefs); + trimmedRefs.addAll(getAggCallFilterRefs(aggExprList, context)); context.relBuilder.project(trimmedRefs); // Re-resolve all attributes based on adding trimmed Project. @@ -1795,6 +1796,21 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e return null; } + /** + * Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them. + */ + private List getAggCallFilterRefs( + List aggExprList, CalcitePlanContext context) { + List refs = new ArrayList<>(); + for (UnresolvedExpression aggExpr : aggExprList) { + AggregateFunction aggFunc = extractAggregateFunction(aggExpr); + if (aggFunc != null && aggFunc.condition() != null) { + refs.addAll(PlanUtils.getInputRefs(rexVisitor.analyze(aggFunc.condition(), context))); + } + } + return refs; + } + private Optional getTimeSpanField(UnresolvedExpression expr) { if (Objects.isNull(expr)) return Optional.empty(); if (expr instanceof Span span && SpanUnit.isTimeUnit(span.getUnit())) {