Skip to content

Commit f12e4c3

Browse files
authored
[BugFix] Fix FILTER(WHERE) dropped on aggregates in unified SQL path (#5523)
Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent f2acb89 commit f12e4c3

3 files changed

Lines changed: 69 additions & 4 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,47 @@ GROUP BY department HAVING MAX(age) > 30
346346
""");
347347
}
348348

349+
@Test
350+
public void testCountStarWithFilter() {
351+
givenQuery("SELECT COUNT(*) FILTER(WHERE age > 30) FROM catalog.employees")
352+
.assertPlan(
353+
"""
354+
LogicalAggregate(group=[{}], COUNT(*) FILTER(WHERE age > 30)=[COUNT() FILTER $0])
355+
LogicalProject($f1=[>($2, 30)])
356+
LogicalTableScan(table=[[catalog, employees]])
357+
""");
358+
}
359+
360+
@Test
361+
public void testFilteredAggregateWithGroupBy() {
362+
givenQuery(
363+
"""
364+
SELECT department, SUM(age) FILTER(WHERE age > 30) FROM catalog.employees
365+
GROUP BY department
366+
""")
367+
.assertPlan(
368+
"""
369+
LogicalAggregate(group=[{0}], SUM(age) FILTER(WHERE age > 30)=[SUM($1) FILTER $2])
370+
LogicalProject(department=[$3], age=[$2], $f3=[>($2, 30)])
371+
LogicalTableScan(table=[[catalog, employees]])
372+
""");
373+
}
374+
375+
@Test
376+
public void testMultipleFilteredAggregates() {
377+
givenQuery(
378+
"""
379+
SELECT MAX(age) FILTER(WHERE age > 30), MIN(age) FILTER(WHERE age < 50)
380+
FROM catalog.employees
381+
""")
382+
.assertPlan(
383+
"""
384+
LogicalAggregate(group=[{}], MAX(age) FILTER(WHERE age > 30)=[MAX($0) FILTER $1], MIN(age) FILTER(WHERE age < 50)=[MIN($0) FILTER $2])
385+
LogicalProject(age=[$2], $f4=[>($2, 30)], $f5=[<($2, 50)])
386+
LogicalTableScan(table=[[catalog, employees]])
387+
""");
388+
}
389+
349390
@Test
350391
public void testScalarFnOverAggregate() {
351392
givenQuery("SELECT ABS(MAX(age)) FROM catalog.employees")

core/src/main/java/org/opensearch/sql/calcite/CalciteAggCallVisitor.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.ArrayList;
99
import java.util.List;
1010
import org.apache.calcite.rex.RexNode;
11+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
1112
import org.apache.calcite.tools.RelBuilder.AggCall;
1213
import org.apache.logging.log4j.util.Strings;
1314
import org.opensearch.sql.ast.AbstractNodeVisitor;
@@ -46,10 +47,17 @@ public AggCall visitAggregateFunction(AggregateFunction node, CalcitePlanContext
4647
}
4748
return BuiltinFunctionName.ofAggregation(node.getFuncName())
4849
.map(
49-
functionName -> {
50-
return PlanUtils.makeAggCall(
51-
context, functionName, node.getDistinct(), field, argList);
52-
})
50+
functionName ->
51+
PlanUtils.makeAggCall(context, functionName, node.getDistinct(), field, argList))
52+
// Apply the optional FILTER(WHERE ...) predicate; IS TRUE treats NULL as non-matching.
53+
.map(
54+
aggCall ->
55+
node.condition() == null
56+
? aggCall
57+
: aggCall.filter(
58+
context.rexBuilder.makeCall(
59+
SqlStdOperatorTable.IS_TRUE,
60+
rexNodeVisitor.analyze(node.condition(), context))))
5361
.orElseThrow(
5462
() ->
5563
new UnsupportedOperationException("Unexpected aggregation: " + node.getFuncName()));

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,7 @@ private Pair<List<RexNode>, List<AggCall>> aggregateWithTrimming(
15631563
List<RexInputRef> aggCallRefs = PlanUtils.getInputRefsFromAggCall(resolvedAggCallList);
15641564
boolean hintNestedAgg = containsNestedAggregator(context.relBuilder, aggCallRefs);
15651565
trimmedRefs.addAll(aggCallRefs);
1566+
trimmedRefs.addAll(getAggCallFilterRefs(aggExprList, context));
15661567
context.relBuilder.project(trimmedRefs);
15671568

15681569
// Re-resolve all attributes based on adding trimmed Project.
@@ -1795,6 +1796,21 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e
17951796
return null;
17961797
}
17971798

1799+
/**
1800+
* Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them.
1801+
*/
1802+
private List<RexInputRef> getAggCallFilterRefs(
1803+
List<UnresolvedExpression> aggExprList, CalcitePlanContext context) {
1804+
List<RexInputRef> refs = new ArrayList<>();
1805+
for (UnresolvedExpression aggExpr : aggExprList) {
1806+
AggregateFunction aggFunc = extractAggregateFunction(aggExpr);
1807+
if (aggFunc != null && aggFunc.condition() != null) {
1808+
refs.addAll(PlanUtils.getInputRefs(rexVisitor.analyze(aggFunc.condition(), context)));
1809+
}
1810+
}
1811+
return refs;
1812+
}
1813+
17981814
private Optional<UnresolvedExpression> getTimeSpanField(UnresolvedExpression expr) {
17991815
if (Objects.isNull(expr)) return Optional.empty();
18001816
if (expr instanceof Span span && SpanUnit.isTimeUnit(span.getUnit())) {

0 commit comments

Comments
 (0)