Skip to content

Commit 5ed2a28

Browse files
Implement SQL validation based on grammar element (#3039) (#3044)
* Implement SQL validation based on grammar element * Add function types * fix style * Add security lake * Add File support * Integrate into SparkQueryDispatcher * Fix style * Add tests * Integration * Add comments * Address comments * Allow join types for now * Fix style * Fix coverage check --------- (cherry picked from commit a87893a) Signed-off-by: Tomoyuki Morita <moritato@amazon.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent a20f655 commit 5ed2a28

22 files changed

Lines changed: 2195 additions & 192 deletions

async-query-core/build.gradle

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ jacocoTestCoverageVerification {
122122
'org.opensearch.sql.spark.flint.*',
123123
'org.opensearch.sql.spark.flint.operation.*',
124124
'org.opensearch.sql.spark.rest.*',
125-
'org.opensearch.sql.spark.utils.SQLQueryUtils.*'
125+
'org.opensearch.sql.spark.utils.SQLQueryUtils.*',
126+
'org.opensearch.sql.spark.validator.SQLQueryValidationVisitor'
126127
]
127128
limit {
128129
counter = 'LINE'

async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.sql.spark.dispatcher;
77

88
import java.util.HashMap;
9-
import java.util.List;
109
import java.util.Map;
1110
import lombok.AllArgsConstructor;
1211
import org.jetbrains.annotations.NotNull;
@@ -24,6 +23,7 @@
2423
import org.opensearch.sql.spark.execution.session.SessionManager;
2524
import org.opensearch.sql.spark.rest.model.LangType;
2625
import org.opensearch.sql.spark.utils.SQLQueryUtils;
26+
import org.opensearch.sql.spark.validator.SQLQueryValidator;
2727

2828
/** This class takes care of understanding query and dispatching job query to emr serverless. */
2929
@AllArgsConstructor
@@ -38,6 +38,7 @@ public class SparkQueryDispatcher {
3838
private final SessionManager sessionManager;
3939
private final QueryHandlerFactory queryHandlerFactory;
4040
private final QueryIdProvider queryIdProvider;
41+
private final SQLQueryValidator sqlQueryValidator;
4142

4243
public DispatchQueryResponse dispatch(
4344
DispatchQueryRequest dispatchQueryRequest,
@@ -54,13 +55,7 @@ public DispatchQueryResponse dispatch(
5455
dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
5556
}
5657

57-
List<String> validationErrors =
58-
SQLQueryUtils.validateSparkSqlQuery(
59-
dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query);
60-
if (!validationErrors.isEmpty()) {
61-
throw new IllegalArgumentException(
62-
"Query is not allowed: " + String.join(", ", validationErrors));
63-
}
58+
sqlQueryValidator.validate(query, dataSourceMetadata.getConnector());
6459
}
6560
return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
6661
}

async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
package org.opensearch.sql.spark.utils;
77

8-
import java.util.ArrayList;
9-
import java.util.Collections;
108
import java.util.LinkedList;
119
import java.util.List;
1210
import java.util.Locale;
@@ -20,8 +18,6 @@
2018
import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
2119
import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
2220
import org.opensearch.sql.common.antlr.SyntaxCheckException;
23-
import org.opensearch.sql.datasource.model.DataSource;
24-
import org.opensearch.sql.datasource.model.DataSourceType;
2521
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
2622
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
2723
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
@@ -84,71 +80,12 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
8480
}
8581
}
8682

87-
public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
83+
public static SqlBaseParser getBaseParser(String sqlQuery) {
8884
SqlBaseParser sqlBaseParser =
8985
new SqlBaseParser(
9086
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
9187
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
92-
try {
93-
SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
94-
StatementContext statement = sqlBaseParser.statement();
95-
sqlParserBaseVisitor.visit(statement);
96-
return sqlParserBaseVisitor.getValidationErrors();
97-
} catch (SyntaxCheckException e) {
98-
logger.error(
99-
String.format(
100-
"Failed to parse sql statement context while validating sql query %s", sqlQuery),
101-
e);
102-
return Collections.emptyList();
103-
}
104-
}
105-
106-
private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
107-
if (datasource != null
108-
&& datasource.getConnectorType() != null
109-
&& datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
110-
return new SparkSqlSecurityLakeValidatorVisitor();
111-
} else {
112-
return new SparkSqlValidatorVisitor();
113-
}
114-
}
115-
116-
/**
117-
* A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class
118-
* supports accumulating validation errors on visiting sql statement
119-
*/
120-
@Getter
121-
private static class SqlBaseValidatorVisitor<T> extends SqlBaseParserBaseVisitor<T> {
122-
private final List<String> validationErrors = new ArrayList<>();
123-
}
124-
125-
/** A generic validator impl for Spark Sql Queries */
126-
private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
127-
@Override
128-
public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
129-
getValidationErrors().add("Creating user-defined functions is not allowed");
130-
return super.visitCreateFunction(ctx);
131-
}
132-
}
133-
134-
/** A validator impl specific to Security Lake for Spark Sql Queries */
135-
private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
136-
137-
public SparkSqlSecurityLakeValidatorVisitor() {
138-
// only select statement allowed. hence we add the validation error to all types of statements
139-
// by default
140-
// and remove the validation error only for select statement.
141-
getValidationErrors()
142-
.add(
143-
"Unsupported sql statement for security lake data source. Only select queries are"
144-
+ " allowed");
145-
}
146-
147-
@Override
148-
public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
149-
getValidationErrors().clear();
150-
return super.visitStatementDefault(ctx);
151-
}
88+
return sqlBaseParser;
15289
}
15390

15491
public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.spark.validator;
7+
8+
public class DefaultGrammarElementValidator implements GrammarElementValidator {
9+
@Override
10+
public boolean isValid(GrammarElement element) {
11+
return true;
12+
}
13+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.spark.validator;
7+
8+
import java.util.Set;
9+
import lombok.RequiredArgsConstructor;
10+
11+
@RequiredArgsConstructor
12+
public class DenyListGrammarElementValidator implements GrammarElementValidator {
13+
private final Set<GrammarElement> denyList;
14+
15+
@Override
16+
public boolean isValid(GrammarElement element) {
17+
return !denyList.contains(element);
18+
}
19+
}

0 commit comments

Comments
 (0)