diff --git a/api/src/main/java/org/opensearch/sql/api/parser/SqlV2QueryParser.java b/api/src/main/java/org/opensearch/sql/api/parser/SqlV2QueryParser.java index d60ef7ed4ec..1078c58bcc1 100644 --- a/api/src/main/java/org/opensearch/sql/api/parser/SqlV2QueryParser.java +++ b/api/src/main/java/org/opensearch/sql/api/parser/SqlV2QueryParser.java @@ -10,25 +10,36 @@ import static org.opensearch.sql.ast.dsl.AstDSL.join; import static org.opensearch.sql.ast.dsl.AstDSL.union; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.statement.Query; import org.opensearch.sql.ast.statement.Statement; import org.opensearch.sql.ast.tree.Join.JoinType; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FromClauseContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.JoinClauseContext; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OrderByClauseContext; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.QuerySpecificationContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.UnionSelectContext; import org.opensearch.sql.sql.parser.AstBuilder; import org.opensearch.sql.sql.parser.AstExpressionBuilder; +import org.opensearch.sql.sql.parser.AstSortBuilder; import org.opensearch.sql.sql.parser.AstStatementBuilder; +import org.opensearch.sql.sql.parser.context.QuerySpecification; /** SQL query parser that produces {@link UnresolvedPlan} using the V2 ANTLR grammar. */ public class SqlV2QueryParser implements UnifiedQueryParser { @@ -62,6 +73,53 @@ private static class ExtendedAstBuilder extends AstBuilder { super(query); } + @Override + public UnresolvedPlan visitQuerySpecification(QuerySpecificationContext queryContext) { + if (!hasWindowFunctionInProjectList(queryContext)) { + return super.visitQuerySpecification(queryContext); + } + + context.push(); + context.peek().collect(queryContext, query); + Project project = (Project) visit(queryContext.selectClause()); + UnresolvedPlan result = project.attach(visit(queryContext.fromClause())); + + // Window output must be computed before ORDER BY/LIMIT, so build Limit(Sort(Project(from))) + OrderByClauseContext orderByClause = queryContext.fromClause().orderByClause(); + if (orderByClause != null) { + result = new ExtendedAstSortBuilder(context.peek()).visit(orderByClause).attach(result); + } + if (queryContext.limitClause() != null) { + result = visit(queryContext.limitClause()).attach(result); + } + + context.pop(); + return result; + } + + @Override + public UnresolvedPlan visitFromClause(FromClauseContext ctx) { + UnresolvedPlan from = super.visitFromClause(ctx); + if (hasWindowFunctionInProjectList(context.peek()) && from instanceof Sort sort) { + // Drop the ORDER BY Sort for window queries; it is re-attached above the Project + return sort.getChild().get(0); + } + return from; + } + + private boolean hasWindowFunctionInProjectList(QuerySpecificationContext queryContext) { + if (queryContext.fromClause() == null) { + return false; + } + QuerySpecification probe = new QuerySpecification(); + probe.collect(queryContext, query); + return hasWindowFunctionInProjectList(probe); + } + + private static boolean hasWindowFunctionInProjectList(QuerySpecification querySpec) { + return querySpec.getSelectItems().stream().anyMatch(item -> item instanceof WindowFunction); + } + @Override protected AstExpressionBuilder createExpressionBuilder() { return new ExtendedAstExpressionBuilder(); @@ -114,4 +172,32 @@ public UnresolvedExpression visitExistsSubqueryExpressionAtom( } } } + + /** + * Keeps an ORDER BY window-alias as a column reference (Sort is above the Project) to avoid a + * second RexOver. + */ + private static class ExtendedAstSortBuilder extends AstSortBuilder { + + ExtendedAstSortBuilder(QuerySpecification querySpec) { + super(querySpec); + } + + @Override + public UnresolvedPlan visitOrderByClause(OrderByClauseContext ctx) { + List fields = new ArrayList<>(); + List items = querySpec.getOrderByItems(); + List options = querySpec.getOrderByOptions(); + for (int i = 0; i < items.size(); i++) { + UnresolvedExpression item = items.get(i); + UnresolvedExpression sortKey = + (querySpec.isSelectAlias(item) + && querySpec.getSelectItemByAlias(item) instanceof WindowFunction) + ? item + : querySpec.replaceIfAliasOrOrdinal(item); + fields.add(new Field(sortKey, createSortArguments(options.get(i)))); + } + return new Sort(fields); + } + } } diff --git a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java index 8538ebbae61..064eff32d76 100644 --- a/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java +++ b/api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java @@ -547,4 +547,73 @@ SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH LogicalTableScan(table=[[catalog, employees]]) """); } + + @Test + public void testWindowOverGroupByWithLimit() { + givenQuery( + """ + SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn + FROM catalog.employees GROUP BY department LIMIT 3 + """) + .assertPlan( + """ + LogicalSort(fetch=[3]) + LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)]) + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(department=[$3]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testWindowOverGroupByOrderByWindowAlias() { + givenQuery( + """ + SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn + FROM catalog.employees GROUP BY department ORDER BY rn LIMIT 3 + """) + .assertPlan( + """ + LogicalSort(sort0=[$2], dir0=[ASC-nulls-first], fetch=[3]) + LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)]) + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(department=[$3]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testWindowOverGroupByOrderByWindowAliasWithoutLimit() { + givenQuery( + """ + SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn + FROM catalog.employees GROUP BY department ORDER BY rn + """) + .assertPlan( + """ + LogicalSort(sort0=[$2], dir0=[ASC-nulls-first]) + LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)]) + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(department=[$3]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } + + @Test + public void testMultipleWindowFunctionsOrderByWindowAlias() { + givenQuery( + """ + SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn, + ROW_NUMBER() OVER (ORDER BY department) AS rn2 + FROM catalog.employees GROUP BY department ORDER BY rn LIMIT 3 + """) + .assertPlan( + """ + LogicalSort(sort0=[$2], dir0=[ASC-nulls-first], fetch=[3]) + LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)], rn2=[ROW_NUMBER() OVER (ORDER BY $0 NULLS FIRST)]) + LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()]) + LogicalProject(department=[$3]) + LogicalTableScan(table=[[catalog, employees]]) + """); + } } diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java index aaed2ba5ec2..641ef0d39ca 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java @@ -54,13 +54,13 @@ public class AstBuilder extends OpenSearchSQLParserBaseVisitor { private final AstExpressionBuilder expressionBuilder; /** Parsing context stack that contains context for current query parsing. */ - private final ParsingContext context = new ParsingContext(); + protected final ParsingContext context = new ParsingContext(); /** * SQL query to get original token text. This is necessary because token.getText() returns text * without whitespaces or other characters discarded by lexer. */ - private final String query; + protected final String query; public AstBuilder(String query) { this.query = query; diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java index 2594709f4f4..1647bc9ee9e 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstSortBuilder.java @@ -33,7 +33,7 @@ @RequiredArgsConstructor public class AstSortBuilder extends OpenSearchSQLParserBaseVisitor { - private final QuerySpecification querySpec; + protected final QuerySpecification querySpec; @Override public UnresolvedPlan visitOrderByClause(OrderByClauseContext ctx) { @@ -57,7 +57,7 @@ private List createSortFields() { * Argument "asc" is required. Argument "nullFirst" is optional and determined by Analyzer later * if absent. */ - private List createSortArguments(SortOption option) { + protected List createSortArguments(SortOption option) { SortOrder sortOrder = option.getSortOrder(); NullOrder nullOrder = option.getNullOrder(); ImmutableList.Builder args = ImmutableList.builder();