Skip to content

Commit 6bd1d77

Browse files
committed
fix: Reference group-key column for GROUP BY expressions
A SELECT/HAVING/ORDER BY expression matching a GROUP BY expression now resolves to the materialized group-key column via a context-scoped index checked in visitFunction, instead of recomputing it from base fields the aggregation removed (which failed with "Field not found"). Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent 7f31510 commit 6bd1d77

4 files changed

Lines changed: 67 additions & 0 deletions

File tree

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,4 +506,45 @@ SELECT name, ROW_NUMBER() OVER (ORDER BY id) AS rn FROM catalog.employees
506506
LogicalTableScan(table=[[catalog, employees]])
507507
""");
508508
}
509+
510+
@Test
511+
public void testGroupByExpression() {
512+
givenQuery("SELECT LENGTH(name), COUNT(*) FROM catalog.employees GROUP BY LENGTH(name)")
513+
.assertPlan(
514+
"""
515+
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
516+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
517+
LogicalTableScan(table=[[catalog, employees]])
518+
""");
519+
}
520+
521+
@Test
522+
public void testHavingOnGroupByExpression() {
523+
givenQuery(
524+
"SELECT COUNT(*) FROM catalog.employees GROUP BY LENGTH(name) HAVING LENGTH(name) > 3")
525+
.assertPlan(
526+
"""
527+
LogicalProject(COUNT(*)=[$0])
528+
LogicalFilter(condition=[>($1, 3)])
529+
LogicalProject(COUNT(*)=[$1], LENGTH(name)=[$0])
530+
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
531+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
532+
LogicalTableScan(table=[[catalog, employees]])
533+
""");
534+
}
535+
536+
@Test
537+
public void testOrderByGroupByExpression() {
538+
givenQuery(
539+
"""
540+
SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH(name)
541+
""")
542+
.assertPlan(
543+
"""
544+
LogicalSort(sort0=[$0], dir0=[ASC-nulls-first])
545+
LogicalAggregate(group=[{0}])
546+
LogicalProject(LENGTH(name)=[CHAR_LENGTH($1)])
547+
LogicalTableScan(table=[[catalog, employees]])
548+
""");
549+
}
509550
}

core/src/main/java/org/opensearch/sql/calcite/CalcitePlanContext.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.calcite.rex.RexNode;
2323
import org.apache.calcite.tools.FrameworkConfig;
2424
import org.opensearch.sql.ast.expression.AggregateFunction;
25+
import org.opensearch.sql.ast.expression.Function;
2526
import org.opensearch.sql.ast.expression.UnresolvedExpression;
2627
import org.opensearch.sql.ast.tree.HighlightConfig;
2728
import org.opensearch.sql.calcite.utils.CalciteToolsHelper;
@@ -71,6 +72,9 @@ public class CalcitePlanContext {
7172
*/
7273
@Getter private final Map<AggregateFunction, Integer> aggregateOutputIndex = new HashMap<>();
7374

75+
/** Maps GROUP BY Function AST nodes to their output field index for post-aggregate resolution. */
76+
@Getter private final Map<Function, Integer> groupKeyOutputIndex = new HashMap<>();
77+
7478
/**
7579
* List of captured variables from outer scope for lambda functions. When a lambda body references
7680
* a field that is not a lambda parameter, it gets captured and stored here. The captured

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,6 +1788,17 @@ private void visitAggregation(
17881788
context.getAggregateOutputIndex().put(aggFunc, aggStartIdx + i);
17891789
}
17901790
}
1791+
1792+
// Register group-by expression output indices so post-aggregate references resolve to them;
1793+
// clear() safe as above.
1794+
context.getGroupKeyOutputIndex().clear();
1795+
int groupStartIdx = metricsFirst ? aggRexList.size() : 0;
1796+
for (int i = 0; i < groupExprList.size(); i++) {
1797+
Function groupFunc = extractFunction(groupExprList.get(i));
1798+
if (groupFunc != null) {
1799+
context.getGroupKeyOutputIndex().put(groupFunc, groupStartIdx + i);
1800+
}
1801+
}
17911802
}
17921803

17931804
private static AggregateFunction extractAggregateFunction(UnresolvedExpression expr) {
@@ -1796,6 +1807,12 @@ private static AggregateFunction extractAggregateFunction(UnresolvedExpression e
17961807
return null;
17971808
}
17981809

1810+
private static Function extractFunction(UnresolvedExpression expr) {
1811+
if (expr instanceof Function f) return f;
1812+
if (expr instanceof Alias alias) return extractFunction(alias.getDelegated());
1813+
return null;
1814+
}
1815+
17991816
/**
18001817
* Collects input refs used by aggregate FILTER(WHERE ...) predicates so trimming retains them.
18011818
*/

core/src/main/java/org/opensearch/sql/calcite/CalciteRexNodeVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,11 @@ private List<RelDataType> modifyLambdaTypeByFunction(
563563

564564
@Override
565565
public RexNode visitFunction(Function node, CalcitePlanContext context) {
566+
// Resolve a group-by expression to its group-key output index.
567+
Integer groupKeyIndex = context.getGroupKeyOutputIndex().get(node);
568+
if (groupKeyIndex != null) {
569+
return context.relBuilder.field(groupKeyIndex);
570+
}
566571
List<UnresolvedExpression> args = node.getFuncArgs();
567572
List<RexNode> arguments = new ArrayList<>();
568573

0 commit comments

Comments
 (0)