Skip to content

Commit b10d541

Browse files
authored
feat(sql): add IN/EXISTS subquery support in unified query path (opensearch-project#5448)
Add grammar rules (inSubqueryPredicate, existsSubqueryExpressionAtom) and wire them through ExtendedAstExpressionBuilder to produce InSubquery and ExistsSubquery AST nodes for the Calcite-based unified query path. Base AstExpressionBuilder throws SyntaxCheckException to preserve legacy engine fallback. AstBuilder now uses createExpressionBuilder() factory method to allow subclass customization. Also add Alias handling in CalciteRelNodeVisitor.expandProjectFields required for any non-SELECT * query in the unified path. Signed-off-by: Chen Dai <daichen@amazon.com>
1 parent bd99329 commit b10d541

9 files changed

Lines changed: 174 additions & 3 deletions

File tree

api/src/main/java/org/opensearch/sql/api/parser/SqlV2QueryParser.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,25 @@
55

66
package org.opensearch.sql.api.parser;
77

8+
import static org.opensearch.sql.ast.dsl.AstDSL.existsSubquery;
9+
import static org.opensearch.sql.ast.dsl.AstDSL.inSubquery;
810
import static org.opensearch.sql.ast.dsl.AstDSL.join;
911

1012
import java.util.Optional;
1113
import org.antlr.v4.runtime.tree.ParseTree;
14+
import org.opensearch.sql.ast.expression.Not;
1215
import org.opensearch.sql.ast.expression.UnresolvedExpression;
1316
import org.opensearch.sql.ast.statement.Query;
1417
import org.opensearch.sql.ast.statement.Statement;
1518
import org.opensearch.sql.ast.tree.Join.JoinType;
1619
import org.opensearch.sql.ast.tree.UnresolvedPlan;
1720
import org.opensearch.sql.sql.antlr.SQLSyntaxParser;
1821
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
22+
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext;
23+
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext;
1924
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.JoinClauseContext;
2025
import org.opensearch.sql.sql.parser.AstBuilder;
26+
import org.opensearch.sql.sql.parser.AstExpressionBuilder;
2127
import org.opensearch.sql.sql.parser.AstStatementBuilder;
2228

2329
/** SQL query parser that produces {@link UnresolvedPlan} using the V2 ANTLR grammar. */
@@ -52,6 +58,11 @@ private static class ExtendedAstBuilder extends AstBuilder {
5258
super(query);
5359
}
5460

61+
@Override
62+
protected AstExpressionBuilder createExpressionBuilder() {
63+
return new ExtendedAstExpressionBuilder();
64+
}
65+
5566
@Override
5667
public UnresolvedPlan visitJoinClause(JoinClauseContext ctx) {
5768
JoinType joinType = toJoinType(ctx);
@@ -69,5 +80,27 @@ private JoinType toJoinType(JoinClauseContext ctx) {
6980
default -> JoinType.INNER;
7081
};
7182
}
83+
84+
/**
85+
* Expression builder with IN/EXISTS subquery support. Accesses the enclosing AstBuilder to
86+
* visit subquery plan nodes. Must be created via {@link #createExpressionBuilder()} because the
87+
* enclosing {@code this} reference is not available during {@code super()} construction.
88+
*/
89+
private class ExtendedAstExpressionBuilder extends AstExpressionBuilder {
90+
91+
@Override
92+
public UnresolvedExpression visitInSubqueryPredicate(InSubqueryPredicateContext ctx) {
93+
UnresolvedPlan subquery = ExtendedAstBuilder.this.visit(ctx.querySpecification());
94+
UnresolvedExpression inExpr = inSubquery(subquery, visit(ctx.predicate()));
95+
return (ctx.NOT() != null) ? new Not(inExpr) : inExpr;
96+
}
97+
98+
@Override
99+
public UnresolvedExpression visitExistsSubqueryExpressionAtom(
100+
ExistsSubqueryExpressionAtomContext ctx) {
101+
UnresolvedPlan subquery = ExtendedAstBuilder.this.visit(ctx.querySpecification());
102+
return existsSubquery(subquery);
103+
}
104+
}
72105
}
73106
}

api/src/test/java/org/opensearch/sql/api/UnifiedQueryPlannerSqlV2Test.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,80 @@ public void testJoinWithFilterAndOrderBy() {
142142
LogicalTableScan(table=[[catalog, departments]])
143143
""");
144144
}
145+
146+
@Test
147+
public void testInSubquery() {
148+
givenQuery(
149+
"""
150+
SELECT name FROM catalog.employees
151+
WHERE age IN (SELECT age FROM catalog.departments WHERE dept_name = 'Engineering')
152+
""")
153+
.assertPlan(
154+
"""
155+
LogicalProject(name=[$1])
156+
LogicalFilter(condition=[IN($2, {
157+
LogicalProject(age=[$cor0.age])
158+
LogicalFilter(condition=[=($1, 'Engineering')])
159+
LogicalTableScan(table=[[catalog, departments]])
160+
})], variablesSet=[[$cor0]])
161+
LogicalTableScan(table=[[catalog, employees]])
162+
""");
163+
}
164+
165+
@Test
166+
public void testExistsSubquery() {
167+
givenQuery(
168+
"""
169+
SELECT name FROM catalog.employees
170+
WHERE EXISTS (SELECT 1 FROM catalog.departments WHERE dept_id = age)
171+
""")
172+
.assertPlan(
173+
"""
174+
LogicalProject(name=[$1])
175+
LogicalFilter(condition=[EXISTS({
176+
LogicalProject(1=[1])
177+
LogicalFilter(condition=[=($0, $cor0.age)])
178+
LogicalTableScan(table=[[catalog, departments]])
179+
})], variablesSet=[[$cor0]])
180+
LogicalTableScan(table=[[catalog, employees]])
181+
""");
182+
}
183+
184+
@Test
185+
public void testNotInSubquery() {
186+
givenQuery(
187+
"""
188+
SELECT name FROM catalog.employees
189+
WHERE age NOT IN (SELECT age FROM catalog.departments WHERE dept_name = 'Engineering')
190+
""")
191+
.assertPlan(
192+
"""
193+
LogicalProject(name=[$1])
194+
LogicalFilter(condition=[NOT(IN($2, {
195+
LogicalProject(age=[$cor0.age])
196+
LogicalFilter(condition=[=($1, 'Engineering')])
197+
LogicalTableScan(table=[[catalog, departments]])
198+
}))], variablesSet=[[$cor0]])
199+
LogicalTableScan(table=[[catalog, employees]])
200+
""");
201+
}
202+
203+
@Test
204+
public void testNotExistsSubquery() {
205+
givenQuery(
206+
"""
207+
SELECT name FROM catalog.employees
208+
WHERE NOT EXISTS (SELECT 1 FROM catalog.departments WHERE dept_id = age)
209+
""")
210+
.assertPlan(
211+
"""
212+
LogicalProject(name=[$1])
213+
LogicalFilter(condition=[NOT(EXISTS({
214+
LogicalProject(1=[1])
215+
LogicalFilter(condition=[=($0, $cor0.age)])
216+
LogicalTableScan(table=[[catalog, departments]])
217+
}))], variablesSet=[[$cor0]])
218+
LogicalTableScan(table=[[catalog, employees]])
219+
""");
220+
}
145221
}

core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
import org.opensearch.sql.ast.expression.When;
4949
import org.opensearch.sql.ast.expression.WindowFunction;
5050
import org.opensearch.sql.ast.expression.Xor;
51+
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
52+
import org.opensearch.sql.ast.expression.subquery.InSubquery;
5153
import org.opensearch.sql.ast.tree.Aggregation;
5254
import org.opensearch.sql.ast.tree.AppendPipe;
5355
import org.opensearch.sql.ast.tree.Bin;
@@ -771,4 +773,12 @@ public static UnresolvedPlan join(
771773
Optional.empty(),
772774
Argument.ArgumentMap.empty());
773775
}
776+
777+
public static InSubquery inSubquery(UnresolvedPlan query, UnresolvedExpression... values) {
778+
return new InSubquery(List.of(values), query);
779+
}
780+
781+
public static ExistsSubquery existsSubquery(UnresolvedPlan query) {
782+
return new ExistsSubquery(query);
783+
}
774784
}

core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,9 @@ private List<RexNode> expandProjectFields(
541541
.filter(addedFields::add)
542542
.forEach(field -> expandedFields.add(context.relBuilder.field(field)));
543543
}
544+
case Alias alias -> {
545+
expandedFields.add(rexVisitor.analyze(alias, context));
546+
}
544547
default ->
545548
throw new IllegalStateException(
546549
"Unexpected expression type in project list: " + expr.getClass().getSimpleName());

integ-test/src/test/java/org/opensearch/sql/legacy/SqlLegacyEngineSanityIT.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,13 @@ public void testLeftJoinFallback() throws IOException {
4545
.formatted(TEST_INDEX_PEOPLE, TEST_INDEX_DOG));
4646
verifyDataRows(result, rows("Daenerys", "rex"));
4747
}
48+
49+
@Test
50+
public void testInSubqueryFallback() throws IOException {
51+
JSONObject result =
52+
executeQuery(
53+
"SELECT a.firstname FROM %s a WHERE a.firstname IN (SELECT holdersName FROM %s)"
54+
.formatted(TEST_INDEX_PEOPLE, TEST_INDEX_DOG));
55+
verifyDataRows(result, rows("Daenerys"), rows("Hattie"));
56+
}
4857
}

sql/src/main/antlr/OpenSearchSQLParser.g4

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ predicate
322322
| left = predicate NOT? LIKE right = predicate # likePredicate
323323
| left = predicate REGEXP right = predicate # regexpPredicate
324324
| predicate NOT? IN '(' expressions ')' # inPredicate
325+
| predicate NOT? IN '(' querySpecification ')' # inSubqueryPredicate
325326
;
326327

327328
expressions
@@ -333,6 +334,7 @@ expressionAtom
333334
| columnName # fullColumnNameExpressionAtom
334335
| functionCall # functionCallExpressionAtom
335336
| LR_BRACKET expression RR_BRACKET # nestedExpressionAtom
337+
| EXISTS LR_BRACKET querySpecification RR_BRACKET # existsSubqueryExpressionAtom
336338
| left = expressionAtom mathOperator = (STAR | SLASH | MODULE) right = expressionAtom # mathExpressionAtom
337339
| left = expressionAtom mathOperator = (PLUS | MINUS) right = expressionAtom # mathExpressionAtom
338340
;

sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.util.Collections;
2424
import java.util.Locale;
2525
import java.util.Optional;
26-
import lombok.RequiredArgsConstructor;
2726
import org.antlr.v4.runtime.tree.ParseTree;
2827
import org.opensearch.sql.ast.expression.Alias;
2928
import org.opensearch.sql.ast.expression.AllFields;
@@ -50,10 +49,9 @@
5049
import org.opensearch.sql.sql.parser.context.ParsingContext;
5150

5251
/** Abstract syntax tree (AST) builder. */
53-
@RequiredArgsConstructor
5452
public class AstBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {
5553

56-
private final AstExpressionBuilder expressionBuilder = new AstExpressionBuilder();
54+
private final AstExpressionBuilder expressionBuilder;
5755

5856
/** Parsing context stack that contains context for current query parsing. */
5957
private final ParsingContext context = new ParsingContext();
@@ -64,6 +62,11 @@ public class AstBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {
6462
*/
6563
private final String query;
6664

65+
public AstBuilder(String query) {
66+
this.query = query;
67+
this.expressionBuilder = createExpressionBuilder();
68+
}
69+
6770
@Override
6871
public UnresolvedPlan visitShowStatement(OpenSearchSQLParser.ShowStatementContext ctx) {
6972
final UnresolvedExpression tableFilter = visitAstExpression(ctx.tableFilter());
@@ -279,6 +282,11 @@ protected UnresolvedExpression visitAstExpression(ParseTree tree) {
279282
return expressionBuilder.visit(tree);
280283
}
281284

285+
/** Override to provide a custom expression builder (e.g., with subquery support). */
286+
protected AstExpressionBuilder createExpressionBuilder() {
287+
return new AstExpressionBuilder();
288+
}
289+
282290
private UnresolvedExpression visitSelectItem(SelectElementContext ctx) {
283291
String name = StringUtils.unquoteIdentifier(getTextInQuery(ctx.expression(), query));
284292
UnresolvedExpression expr = visitAstExpression(ctx.expression());

sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@
2828
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext;
2929
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext;
3030
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext;
31+
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext;
3132
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExtractFunctionCallContext;
3233
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilterClauseContext;
3334
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilteredAggregationFunctionCallContext;
3435
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FunctionArgContext;
3536
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.GetFormatFunctionCallContext;
3637
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.HighlightFunctionCallContext;
3738
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InPredicateContext;
39+
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext;
3840
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext;
3941
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext;
4042
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext;
@@ -82,6 +84,7 @@
8284
import org.opensearch.sql.ast.dsl.AstDSL;
8385
import org.opensearch.sql.ast.expression.*;
8486
import org.opensearch.sql.ast.tree.Sort.SortOption;
87+
import org.opensearch.sql.common.antlr.SyntaxCheckException;
8588
import org.opensearch.sql.common.utils.StringUtils;
8689
import org.opensearch.sql.expression.function.BuiltinFunctionName;
8790
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
@@ -668,4 +671,17 @@ private List<UnresolvedExpression> getExtractFunctionArguments(ExtractFunctionCa
668671
visitFunctionArg(ctx.extractFunction().functionArg()));
669672
return args;
670673
}
674+
675+
@Override
676+
public UnresolvedExpression visitInSubqueryPredicate(InSubqueryPredicateContext ctx) {
677+
throw new SyntaxCheckException(
678+
"IN subquery is not supported in the V2 SQL engine. Falling back to legacy engine.");
679+
}
680+
681+
@Override
682+
public UnresolvedExpression visitExistsSubqueryExpressionAtom(
683+
ExistsSubqueryExpressionAtomContext ctx) {
684+
throw new SyntaxCheckException(
685+
"EXISTS subquery is not supported in the V2 SQL engine. Falling back to legacy engine.");
686+
}
671687
}

sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,4 +756,18 @@ public UnresolvedPlan visitJoinClause(OpenSearchSQLParser.JoinClauseContext ctx)
756756
};
757757
assertNotNull(new SQLSyntaxParser().parse(query).accept(builder));
758758
}
759+
760+
@Test
761+
public void in_subquery_throws_syntax_check_exception() {
762+
assertThrows(
763+
SyntaxCheckException.class,
764+
() -> buildAST("SELECT * FROM t WHERE age IN (SELECT age FROM t2)"));
765+
}
766+
767+
@Test
768+
public void exists_subquery_throws_syntax_check_exception() {
769+
assertThrows(
770+
SyntaxCheckException.class,
771+
() -> buildAST("SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)"));
772+
}
759773
}

0 commit comments

Comments
 (0)