diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java new file mode 100644 index 00000000000..136c7e3f3f9 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchExplainIT.java @@ -0,0 +1,173 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import java.io.IOException; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Explain-plan integration tests for vectorSearch SQL table function. These tests verify DSL + * push-down shape via _explain. They do NOT require the k-NN plugin since _explain only parses and + * plans the query without executing it against a knn index. + */ +public class VectorSearchExplainIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + // _explain needs the index to exist for field resolution. + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + // ── Top-k / radial DSL shape ───────────────────────────────────────── + + @Test + public void testExplainTopKProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=5') AS v " + + "LIMIT 5"); + + // WrapperQueryBuilder wraps the knn JSON — verify the wrapper is present + // and track_scores is enabled for score preservation. + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + assertTrue( + "Explain should contain track_scores:\n" + explain, explain.contains("track_scores")); + } + + @Test + public void testExplainRadialMaxDistanceProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "LIMIT 100"); + + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + } + + @Test + public void testExplainRadialMinScoreProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='min_score=0.8') AS v " + + "LIMIT 100"); + + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + } + + // ── Post-filter DSL shape ──────────────────────────────────────────── + + @Test + public void testExplainPostFilterProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + @Test + public void testExplainCompoundPredicateProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' AND v.age > 30 " + + "LIMIT 10"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (compound WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + @Test + public void testExplainRadialWithWhereProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 100"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + // ── Sort + LIMIT explain ───────────────────────────────────────────── + + @Test + public void testOrderByScoreDescExplainSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score DESC " + + "LIMIT 5"); + + assertTrue( + "Explain should succeed with ORDER BY _score DESC:\n" + explain, + explain.contains("wrapper")); + } + + @Test + public void testExplainLimitWithinKSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=10') AS v " + + "LIMIT 5"); + + assertTrue("Explain should succeed with LIMIT <= k:\n" + explain, explain.contains("wrapper")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java new file mode 100644 index 00000000000..66c63c14e2e --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Integration tests for vectorSearch SQL table function — validation and error paths. These tests + * verify that invalid inputs are rejected with clear error messages. Explain-plan DSL shape tests + * live in {@link VectorSearchExplainIT}. + */ +public class VectorSearchIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + // ── Validation error paths ──────────────────────────────────────────── + + @Test + public void testMutualExclusivityRejectsKAndMaxDistance() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,max_distance=10') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testMutualExclusivityRejectsKAndMinScore() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,min_score=0.5') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testKTooLargeRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=10001') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testKZeroRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=0') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testUnknownOptionKeyRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,method.ef_search=100') AS v")); + + assertThat(ex.getMessage(), containsString("Unknown option key")); + } + + @Test + public void testEmptyVectorRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("must not be empty")); + } + + @Test + public void testInvalidFieldNameRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', " + + "field='field\\\"injection', vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid field name")); + } + + @Test + public void testMissingRequiredOptionRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='') AS v")); + + assertThat(ex.getMessage(), containsString("Missing required option")); + } + + // ── Sort restriction validation ───────────────────────────────────────── + + @Test + public void testOrderByNonScoreFieldRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v.firstname ASC " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("unsupported sort expression")); + } + + @Test + public void testOrderByScoreAscRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score ASC " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("_score ASC is not supported")); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java index f33d5f2fa73..fe79aea5321 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -48,14 +48,17 @@ public TableScanBuilder createScanBuilder() { // Use VectorSearchQueryBuilder to keep knn in must (scoring) context. // WHERE filters will be placed in filter (non-scoring) context. - var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); + var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery(), options); requestBuilder.pushDownTrackedScore(true); - // Top-k mode: default size to k so queries without LIMIT return k results - // instead of falling into the generic large-scan path. - // LIMIT pushdown will further reduce this if present. + // Default size policy: LIMIT pushdown will further reduce if present. if (options.containsKey("k")) { + // Top-k mode: default size to k so queries without LIMIT return k results. requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); + } else { + // Radial mode (max_distance/min_score): cap at maxResultWindow. + // Without an explicit cap, radial queries could return unbounded results. + requestBuilder.pushDownLimitToRequestTotal(getMaxResultWindow(), 0); } Function createScanOperator = @@ -68,6 +71,11 @@ public TableScanBuilder createScanBuilder() { } private QueryBuilder buildKnnQuery() { + return new WrapperQueryBuilder(buildKnnQueryJson()); + } + + // Package-private for testing + String buildKnnQueryJson() { StringBuilder vectorJson = new StringBuilder("["); for (int i = 0; i < vector.length; i++) { if (i > 0) vectorJson.append(","); @@ -88,11 +96,9 @@ private QueryBuilder buildKnnQuery() { } } - String knnQueryJson = - String.format( - "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", - field, vectorJson.toString(), optionsJson.toString()); - return new WrapperQueryBuilder(knnQueryJson); + return String.format( + "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", + field, vectorJson.toString(), optionsJson.toString()); } private static boolean isNumeric(String str) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index c4c383f9623..7f01a973970 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -197,9 +197,20 @@ private void validateOptions(Map options) { throw new ExpressionEvaluationException( "Missing required option: one of k, max_distance, or min_score"); } + // Mutual exclusivity: exactly one search mode allowed + int modeCount = (hasK ? 1 : 0) + (hasMaxDistance ? 1 : 0) + (hasMinScore ? 1 : 0); + if (modeCount > 1) { + throw new ExpressionEvaluationException( + "Only one of k, max_distance, or min_score may be specified"); + } // Parse and canonicalize numeric values — closes JSON injection via option values if (hasK) { parseIntOption(options, "k"); + int k = Integer.parseInt(options.get("k")); + if (k < 1 || k > 10000) { + throw new ExpressionEvaluationException( + String.format("k must be between 1 and 10000, got %d", k)); + } } if (hasMaxDistance) { parseDoubleOption(options, "max_distance"); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java index efc2f333b0d..ca4df9629a9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java @@ -5,14 +5,22 @@ package org.opensearch.sql.opensearch.storage.scan; +import java.util.Map; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalSort; /** * Query builder for vector search that keeps the knn query in a scoring (must) context and puts @@ -25,12 +33,14 @@ public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder { private final QueryBuilder knnQuery; + private final Map options; - public VectorSearchQueryBuilder(OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery) { + public VectorSearchQueryBuilder( + OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map options) { super(requestBuilder); - // Set knn as the initial query (scoring context) requestBuilder.getSourceBuilder().query(knnQuery); this.knnQuery = knnQuery; + this.options = options; } @Override @@ -44,4 +54,50 @@ public boolean pushDownFilter(LogicalFilter filter) { requestBuilder.getSourceBuilder().query(combined); return true; } + + @Override + public boolean pushDownLimit(LogicalLimit limit) { + validateLimitWithinK(limit.getLimit()); + return super.pushDownLimit(limit); + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + // Vector search returns results sorted by _score DESC by default. + // Only _score DESC is meaningful; reject all other sort expressions. + for (Pair sortItem : sort.getSortList()) { + Expression expr = sortItem.getRight(); + if (!(expr instanceof ReferenceExpression) + || !"_score".equals(((ReferenceExpression) expr).getAttr())) { + throw new ExpressionEvaluationException( + String.format( + "vectorSearch only supports ORDER BY _score DESC; " + + "unsupported sort expression: %s", + expr)); + } + if (sortItem.getLeft().getSortOrder() != Sort.SortOrder.DESC) { + throw new ExpressionEvaluationException( + "vectorSearch only supports ORDER BY _score DESC; _score ASC is not supported"); + } + } + // _score DESC is the natural knn order — no need to push the sort itself to OpenSearch. + // Preserve the parent's sort.getCount() → limit pushdown contract: SQL always sets count=0, + // but PPL or future callers may set a non-zero count to combine sort+limit in one node. + if (sort.getCount() != 0) { + validateLimitWithinK(sort.getCount()); + requestBuilder.pushDownLimit(sort.getCount(), 0); + } + return true; + } + + /** Validates that the requested limit does not exceed k in top-k mode. */ + private void validateLimitWithinK(int limit) { + if (options.containsKey("k")) { + int k = Integer.parseInt(options.get("k")); + if (limit > k) { + throw new ExpressionEvaluationException( + String.format("LIMIT %d exceeds k=%d in top-k vector search", limit, k)); + } + } + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java new file mode 100644 index 00000000000..2c90193847e --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.LinkedHashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@ExtendWith(MockitoExtension.class) +class VectorSearchIndexTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Test + void buildKnnQueryJsonTopK() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, 2.0f, 3.0f}, + Map.of("k", "5")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0,3.0],\"k\":5}}}", json); + } + + @Test + void buildKnnQueryJsonRadialMaxDistance() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, 2.0f}, + Map.of("max_distance", "10.5")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0],\"max_distance\":10.5}}}", json); + } + + @Test + void buildKnnQueryJsonRadialMinScore() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {0.5f}, + Map.of("min_score", "0.8")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[0.5],\"min_score\":0.8}}}", json); + } + + @Test + void buildKnnQueryJsonNestedFieldName() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "doc.embedding", + new float[] {1.0f, 2.0f}, + Map.of("k", "10")); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"doc.embedding\""), "Should contain nested field name with dot"); + } + + @Test + void buildKnnQueryJsonMultiElementVector() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, -2.5f, 0.0f, 3.14f, 100.0f}, + Map.of("k", "3")); + + String json = index.buildKnnQueryJson(); + assertTrue( + json.contains("[1.0,-2.5,0.0,3.14,100.0]"), + "Should contain all vector components with correct comma separation"); + } + + @Test + void buildKnnQueryJsonSingleElementVector() { + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {42.0f}, Map.of("k", "1")); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("[42.0]"), "Should contain single-element vector"); + } + + @Test + void buildKnnQueryJsonNumericOptionRenderedUnquoted() { + LinkedHashMap options = new LinkedHashMap<>(); + options.put("k", "5"); + + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, options); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted"); + } + + @Test + void buildKnnQueryJsonNonNumericOptionRenderedQuoted() { + LinkedHashMap options = new LinkedHashMap<>(); + options.put("k", "5"); + options.put("method", "hnsw"); + + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, options); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"method\":\"hnsw\""), "Non-numeric option should be quoted"); + assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted"); + } + + @Test + void isInstanceOfOpenSearchIndex() { + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, Map.of("k", "5")); + assertTrue(index instanceof OpenSearchIndex); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index 71f0bfa80af..ec8e5161d8c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -233,6 +233,65 @@ void testInfiniteMinScoreThrows() { assertTrue(ex.getMessage().contains("must be a finite number")); } + @Test + void testMutualExclusivityKAndMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testMutualExclusivityKAndMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testMutualExclusivityAllThreeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs( + "my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testKTooSmallThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKTooLargeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10001"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKBoundaryValuesAllowed() { + // k=1 should work + VectorSearchTableFunctionImplementation impl1 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=1"); + assertTrue(impl1.applyArguments() instanceof VectorSearchIndex); + + // k=10000 should work + VectorSearchTableFunctionImplementation impl2 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10000"); + assertTrue(impl2.applyArguments() instanceof VectorSearchIndex); + } + @Test void testNonNamedArgThrows() { FunctionName functionName = FunctionName.of("vectorsearch"); @@ -260,6 +319,75 @@ void testNullArgNameThrows() { assertTrue(ex.getMessage().contains("requires named arguments")); } + @Test + void testNaNVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, NaN, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testEmptyOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("=value")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testEmptyOptionValueThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("key=")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testNegativeKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=-1"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testNaNMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testNaNMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testCaseInsensitiveArgLookup() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("TABLE", DSL.literal("my-index")), + DSL.namedArgument("FIELD", DSL.literal("embedding")), + DSL.namedArgument("VECTOR", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("OPTION", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + private VectorSearchTableFunctionImplementation createImpl() { return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java index 77efd0a6d88..4816dd17fdb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java @@ -83,4 +83,48 @@ void testWrongArgumentCount() { IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); assertTrue(ex.getMessage().contains("requires 4 arguments")); } + + @Test + void testTooManyArguments() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0]")), + DSL.namedArgument("option", DSL.literal("k=5")), + DSL.namedArgument("extra", DSL.literal("unexpected"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } + + @Test + void testZeroArguments() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = List.of(); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java index 2df785b41a0..bda51872448 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -6,21 +6,27 @@ package org.opensearch.sql.opensearch.storage.scan; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import java.util.Collections; +import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.WrapperQueryBuilder; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalValues; class VectorSearchQueryBuilderTest { @@ -30,7 +36,7 @@ void knnQuerySetAsScoringQuery() { var requestBuilder = createRequestBuilder(); var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); - new VectorSearchQueryBuilder(requestBuilder, knnQuery); + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); QueryBuilder query = requestBuilder.getSourceBuilder().query(); assertTrue( @@ -42,7 +48,7 @@ void knnQuerySetAsScoringQuery() { void pushDownFilterKeepsKnnInScoringContext() { var requestBuilder = createRequestBuilder(); var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); - var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); // Simulate WHERE name = 'John' var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); @@ -62,6 +68,227 @@ void pushDownFilterKeepsKnnInScoringContext() { "must clause should contain the original knn WrapperQueryBuilder"); } + @Test + void pushDownLimitWithinKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 3, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT within k should succeed"); + } + + @Test + void pushDownLimitExceedingKThrows() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 10, 0); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownLimitEqualToKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 5, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT equal to k should succeed"); + } + + @Test + void pushDownLimitRadialModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "Radial mode should not restrict LIMIT"); + } + + @Test + void pushDownLimitMinScoreModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("min_score", "0.5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "min_score mode should not restrict LIMIT"); + } + + @Test + void pushDownSortScoreDescAccepted() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC should be accepted"); + } + + @Test + void pushDownSortPreservesSortCountAsLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "10")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=7 simulates a sort+limit combined node (PPL path) + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + 7, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC with count should be accepted"); + assertEquals( + 7, + requestBuilder.getMaxResponseSize(), + "sort.getCount() should be pushed down as request size"); + } + + @Test + void pushDownSortCountExceedingKRejects() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=10 exceeds k=5 — should be rejected + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + 10, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownSortNonScoreFieldRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortMultipleExpressionsRejectsNonScore() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)), + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortScoreAscRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("_score ASC is not supported")); + } + + @Test + void pushDownFilterCompoundPredicateSurvives() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // Simulate WHERE name = 'John' AND age > 30 + var condition = + DSL.and( + DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")), + DSL.greater(new ReferenceExpression("age", ExprCoreType.INTEGER), DSL.literal(30))); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter with compound predicate should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "compound WHERE should be in filter (non-scoring)"); + } + private OpenSearchRequestBuilder createRequestBuilder() { return new OpenSearchRequestBuilder( mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class));