Skip to content

Commit d9b031f

Browse files
committed
fix: Reference group-key column for GROUP BY expressions
A SELECT/HAVING/ORDER BY expression matching a GROUP BY expression now resolves to the materialized group-key column via a context-scoped index checked in visitFunction, instead of recomputing it from base fields the aggregation removed (which failed with "Field not found"). Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 7f31510 commit d9b031f

4 files changed

Lines changed: 65 additions & 0 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,45 @@ SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn FROM catalog.employees
506506
LogicalTableScan(table=[[catalog, employees]])
507507
""");
508508
}
509+
510+
/**
 * Plans a query that groups by a function expression and selects that same expression.
 * The expected plan computes the expression once in a projection below the aggregate and
 * groups on that projected column (group key 0) alongside COUNT(*).
 */
@Test
511+
public void testGroupByExpression() {
512+
givenQuery("SELECT LENGTH(name), COUNT(*) FROM catalog.employees GROUP BY LENGTH(name)")
513+
.assertPlan(
514+
"""
515+
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
516+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
517+
LogicalTableScan(table=[[catalog, employees]])
518+
""");
519+
}
520+
521+
/**
 * Plans a query whose HAVING clause repeats the GROUP BY function expression.
 * The expected plan filters on an input reference to the group-key column ($1 of the
 * projection above the aggregate) instead of re-evaluating CHAR_LENGTH, and the final
 * projection keeps only COUNT(*).
 */
@Test
522+
public void testHavingOnGroupByExpression() {
523+
givenQuery(
524+
"SELECT COUNT(*) FROM catalog.employees GROUP BY LENGTH(name) HAVING LENGTH(name) > 3")
525+
.assertPlan(
526+
"""
527+
LogicalProject(COUNT(*)=[$0])
528+
LogicalFilter(condition=[>($1, 3)])
529+
LogicalProject(COUNT(*)=[$1], LENGTH(name)=[$0])
530+
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
531+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
532+
LogicalTableScan(table=[[catalog, employees]])
533+
""");
534+
}
535+
536+
/**
 * Plans a query whose ORDER BY clause repeats the GROUP BY function expression.
 * The expected plan sorts directly on the aggregate's group-key output column ($0); the
 * expression itself appears only in the projection below the aggregate.
 */
@Test
537+
public void testOrderByGroupByExpression() {
538+
givenQuery(
539+
"""
540+
SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH(name)
541+
""")
542+
.assertPlan(
543+
"""
544+
LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])
545+
LogicalAggregate(group=[{0}])
546+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
547+
LogicalTableScan(table=[[catalog, employees]])
548+
""");
549+
}
509550
}

core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.calcite.rex.RexNode;
2323
import org.apache.calcite.tools.FrameworkConfig;
2424
import org.opensearch.sql.ast.expression.AggregateFunction;
25+
import org.opensearch.sql.ast.expression.Function;
2526
import org.opensearch.sql.ast.expression.UnresolvedExpression;
2627
import org.opensearch.sql.ast.tree.HighlightConfig;
2728
import org.opensearch.sql.calcite.utils.CalciteToolsHelper;
@@ -71,6 +72,9 @@ public class CalcitePlanContext {
7172
*/
7273
@Getter private final Map<AggregateFunction, Integer> aggregateOutputIndex = new HashMap<>();
7374

75+
/**
 * Maps GROUP BY Function AST nodes to their output field index for post-aggregate resolution.
 * In this change the map is cleared and repopulated by the aggregation visitor, and read by
 * visitFunction, which returns a field reference at the stored index when the visited node is
 * present as a key.
 */
76+
@Getter private final Map<Function, Integer> groupKeyOutputIndex = new HashMap<>();
77+
7478
/**
7579
* List of captured variables from outer scope for lambda functions. When a lambda body references
7680
* a field that is not a lambda parameter, it gets captured and stored here. The captured

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,14 @@ private void visitAggregation(
17881788
context.getAggregateOutputIndex().put(aggFunc, aggStartIdx + i);
17891789
}
17901790
}
1791+
context.getGroupKeyOutputIndex().clear();
1792+
int groupStartIdx = metricsFirst ? aggRexList.size() : 0;
1793+
for (int i = 0; i < groupExprList.size(); i++) {
1794+
Function groupFunc = extractFunction(groupExprList.get(i));
1795+
if (groupFunc != null) {
1796+
context.getGroupKeyOutputIndex().put(groupFunc, groupStartIdx + i);
1797+
}
1798+
}
17911799
}
17921800

17931801
private static AggregateFunction extractAggregateFunction(UnresolvedExpression expr) {
@@ -1796,6 +1804,12 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e
17961804
return null;
17971805
}
17981806

1807+
/**
 * Returns the Function behind a GROUP BY expression, looking through Alias wrappers.
 *
 * @param expr the expression to inspect
 * @return {@code expr} itself when it is a Function; the result of recursing into
 *     {@code getDelegated()} when it is an Alias; otherwise {@code null}
 */
private static Function extractFunction(UnresolvedExpression expr) {
1808+
if (expr instanceof Function f) return f;
1809+
// An alias may wrap a function (or another alias), so unwrap recursively.
if (expr instanceof Alias alias) return extractFunction(alias.getDelegated());
1810+
return null;
1811+
}
1812+
17991813
/**
18001814
* Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them.
18011815
*/

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,12 @@ private List<RelDataType> modifyLambdaTypeByFunction(
563563

564564
@Override
565565
public RexNode visitFunction(Function node, CalcitePlanContext context) {
566+
// Post-aggregate, a GROUP BY function expression is a materialized output column; reference it
567+
// instead of recomputing from base fields the aggregation removed.
568+
Integer groupKeyIndex = context.getGroupKeyOutputIndex().get(node);
569+
if (groupKeyIndex != null) {
570+
return context.relBuilder.field(groupKeyIndex);
571+
}
566572
List<UnresolvedExpression> args = node.getFuncArgs();
567573
List<RexNode> arguments = new ArrayList<>();
568574

0 commit comments

Comments
 (0)