Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,36 @@
import static org.opensearch.sql.ast.dsl.AstDSL.join;
import static org.opensearch.sql.ast.dsl.AstDSL.union;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.ast.statement.Statement;
import org.opensearch.sql.ast.tree.Join.JoinType;
import org.opensearch.sql.ast.tree.Project;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.sql.antlr.SQLSyntaxParser;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FromClauseContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.JoinClauseContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OrderByClauseContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.QuerySpecificationContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.UnionSelectContext;
import org.opensearch.sql.sql.parser.AstBuilder;
import org.opensearch.sql.sql.parser.AstExpressionBuilder;
import org.opensearch.sql.sql.parser.AstSortBuilder;
import org.opensearch.sql.sql.parser.AstStatementBuilder;
import org.opensearch.sql.sql.parser.context.QuerySpecification;

/** SQL query parser that produces {@link UnresolvedPlan} using the V2 ANTLR grammar. */
public class SqlV2QueryParser implements UnifiedQueryParser<UnresolvedPlan> {
Expand Down Expand Up @@ -62,6 +73,53 @@ private static class ExtendedAstBuilder extends AstBuilder {
super(query);
}

@Override
public UnresolvedPlan visitQuerySpecification(QuerySpecificationContext queryContext) {
if (!hasWindowFunctionInProjectList(queryContext)) {
return super.visitQuerySpecification(queryContext);
}

context.push();
context.peek().collect(queryContext, query);
Project project = (Project) visit(queryContext.selectClause());
UnresolvedPlan result = project.attach(visit(queryContext.fromClause()));

// Window output must be computed before ORDER BY/LIMIT, so build Limit(Sort(Project(from)))
OrderByClauseContext orderByClause = queryContext.fromClause().orderByClause();
if (orderByClause != null) {
result = new ExtendedAstSortBuilder(context.peek()).visit(orderByClause).attach(result);
}
if (queryContext.limitClause() != null) {
result = visit(queryContext.limitClause()).attach(result);
}

context.pop();
return result;
}

@Override
public UnresolvedPlan visitFromClause(FromClauseContext ctx) {
UnresolvedPlan from = super.visitFromClause(ctx);
if (hasWindowFunctionInProjectList(context.peek()) && from instanceof Sort sort) {
// Drop the ORDER BY Sort for window queries; it is re-attached above the Project
return sort.getChild().get(0);
}
return from;
}

private boolean hasWindowFunctionInProjectList(QuerySpecificationContext queryContext) {
if (queryContext.fromClause() == null) {
return false;
}
QuerySpecification probe = new QuerySpecification();
probe.collect(queryContext, query);
return hasWindowFunctionInProjectList(probe);
}

private static boolean hasWindowFunctionInProjectList(QuerySpecification querySpec) {
return querySpec.getSelectItems().stream().anyMatch(item -> item instanceof WindowFunction);
}

@Override
protected AstExpressionBuilder createExpressionBuilder() {
return new ExtendedAstExpressionBuilder();
Expand Down Expand Up @@ -114,4 +172,32 @@ public UnresolvedExpression visitExistsSubqueryExpressionAtom(
}
}
}

/**
* Keeps an ORDER BY window-alias as a column reference (Sort is above the Project) to avoid a
* second RexOver.
*/
private static class ExtendedAstSortBuilder extends AstSortBuilder {

ExtendedAstSortBuilder(QuerySpecification querySpec) {
super(querySpec);
}

@Override
public UnresolvedPlan visitOrderByClause(OrderByClauseContext ctx) {
List<Field> fields = new ArrayList<>();
List<UnresolvedExpression> items = querySpec.getOrderByItems();
List<SortOption> options = querySpec.getOrderByOptions();
for (int i = 0; i < items.size(); i++) {
UnresolvedExpression item = items.get(i);
UnresolvedExpression sortKey =
(querySpec.isSelectAlias(item)
&& querySpec.getSelectItemByAlias(item) instanceof WindowFunction)
? item
: querySpec.replaceIfAliasOrOrdinal(item);
fields.add(new Field(sortKey, createSortArguments(options.get(i))));
}
return new Sort(fields);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,73 @@ SELECT LENGTH(name) FROM catalog.employees GROUP BY LENGTH(name) ORDER BY LENGTH
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testWindowOverGroupByWithLimit() {
givenQuery(
"""
SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn
FROM catalog.employees GROUP BY department LIMIT 3
""")
.assertPlan(
"""
LogicalSort(fetch=[3])
LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)])
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(department=[$3])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testWindowOverGroupByOrderByWindowAlias() {
givenQuery(
"""
SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn
FROM catalog.employees GROUP BY department ORDER BY rn LIMIT 3
""")
.assertPlan(
"""
LogicalSort(sort0=[$2], dir0=[ASC-nulls-first], fetch=[3])
LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)])
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(department=[$3])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testWindowOverGroupByOrderByWindowAliasWithoutLimit() {
givenQuery(
"""
SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn
FROM catalog.employees GROUP BY department ORDER BY rn
""")
.assertPlan(
"""
LogicalSort(sort0=[$2], dir0=[ASC-nulls-first])
LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)])
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(department=[$3])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testMultipleWindowFunctionsOrderByWindowAlias() {
givenQuery(
"""
SELECT department, COUNT(*) AS cnt, ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC) AS rn,
ROW_NUMBER() OVER (ORDER BY department) AS rn2
FROM catalog.employees GROUP BY department ORDER BY rn LIMIT 3
""")
.assertPlan(
"""
LogicalSort(sort0=[$2], dir0=[ASC-nulls-first], fetch=[3])
LogicalProject(department=[$0], cnt=[$1], rn=[ROW_NUMBER() OVER (ORDER BY $1 DESC NULLS FIRST)], rn2=[ROW_NUMBER() OVER (ORDER BY $0 NULLS FIRST)])
LogicalAggregate(group=[{0}], COUNT(*)=[COUNT()])
LogicalProject(department=[$3])
LogicalTableScan(table=[[catalog, employees]])
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ public class AstBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {
private final AstExpressionBuilder expressionBuilder;

/** Parsing context stack that contains context for current query parsing. */
private final ParsingContext context = new ParsingContext();
protected final ParsingContext context = new ParsingContext();

/**
* SQL query to get original token text. This is necessary because token.getText() returns text
* without whitespaces or other characters discarded by lexer.
*/
private final String query;
protected final String query;

public AstBuilder(String query) {
this.query = query;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
@RequiredArgsConstructor
public class AstSortBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {

private final QuerySpecification querySpec;
protected final QuerySpecification querySpec;

@Override
public UnresolvedPlan visitOrderByClause(OrderByClauseContext ctx) {
Expand All @@ -57,7 +57,7 @@ private List<Field> createSortFields() {
* Argument "asc" is required. Argument "nullFirst" is optional and determined by Analyzer later
* if absent.
*/
private List<Argument> createSortArguments(SortOption option) {
protected List<Argument> createSortArguments(SortOption option) {
SortOrder sortOrder = option.getSortOrder();
NullOrder nullOrder = option.getNullOrder();
ImmutableList.Builder<Argument> args = ImmutableList.builder();
Expand Down
Loading