Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.sql;

import java.io.IOException;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.TestsConstants;

/**
 * Integration tests that inspect the {@code _explain} output of the vectorSearch SQL table
 * function. Every test only checks that the expected DSL keywords occur somewhere in the explain
 * text. As stated by the original authors, the k-NN plugin is not required because {@code
 * _explain} only parses and plans the query without executing it against a knn index.
 */
public class VectorSearchExplainIT extends SQLIntegTestCase {

  private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT;

  @Override
  protected void init() throws Exception {
    // The index has to exist so that _explain can resolve fields.
    loadIndex(Index.ACCOUNT);
  }

  // ── Shared helpers ───────────────────────────────────────────────────

  /**
   * Builds a vectorSearch query against {@link #TEST_INDEX} on field {@code embedding} and returns
   * its explain output.
   *
   * @param vector vector literal placed inside {@code vector='...'}
   * @param option option string placed inside {@code option='...'}
   * @param trailingClauses SQL appended after {@code AS v } (WHERE / ORDER BY / LIMIT)
   * @return whatever {@code explainQuery} returns for the assembled statement
   * @throws IOException propagated from {@code explainQuery}
   */
  private String explainVectorSearch(String vector, String option, String trailingClauses)
      throws IOException {
    String sql =
        "SELECT v._id, v._score FROM vectorSearch(table='"
            + TEST_INDEX
            + "', field='embedding', vector='"
            + vector
            + "', option='"
            + option
            + "') AS v "
            + trailingClauses;
    return explainQuery(sql);
  }

  /**
   * Asserts that {@code token} occurs in {@code explain}; the failure message is the expectation
   * text followed by the full explain output.
   */
  private static void assertExplainHas(String explain, String expectation, String token) {
    assertTrue(expectation + ":\n" + explain, explain.contains(token));
  }

  /**
   * Asserts the bool / must / filter keywords in that order.
   *
   * @param explain explain output to inspect
   * @param filterExpectation failure-message prefix used for the {@code filter} check
   */
  private static void assertBoolMustFilter(String explain, String filterExpectation) {
    assertExplainHas(explain, "Explain should contain bool query", "bool");
    assertExplainHas(
        explain, "Explain should contain must clause (knn in scoring context)", "must");
    assertExplainHas(explain, filterExpectation, "filter");
  }

  // ── Top-k / radial DSL shape ─────────────────────────────────────────

  @Test
  public void testExplainTopKProducesKnnQuery() throws IOException {
    String plan = explainVectorSearch("[1.0, 2.0, 3.0]", "k=5", "LIMIT 5");

    // The knn JSON is carried by a WrapperQueryBuilder, so look for the wrapper
    // and for track_scores, which is enabled for score preservation.
    assertExplainHas(plan, "Explain should contain wrapper query", "wrapper");
    assertExplainHas(plan, "Explain should contain track_scores", "track_scores");
  }

  @Test
  public void testExplainRadialMaxDistanceProducesKnnQuery() throws IOException {
    String plan = explainVectorSearch("[1.0, 2.0]", "max_distance=10.5", "LIMIT 100");

    assertExplainHas(plan, "Explain should contain wrapper query", "wrapper");
  }

  @Test
  public void testExplainRadialMinScoreProducesKnnQuery() throws IOException {
    String plan = explainVectorSearch("[1.0, 2.0]", "min_score=0.8", "LIMIT 100");

    assertExplainHas(plan, "Explain should contain wrapper query", "wrapper");
  }

  // ── Post-filter DSL shape ────────────────────────────────────────────

  @Test
  public void testExplainPostFilterProducesBoolQuery() throws IOException {
    String plan =
        explainVectorSearch("[1.0, 2.0, 3.0]", "k=10", "WHERE v.state = 'TX' LIMIT 10");

    assertBoolMustFilter(
        plan, "Explain should contain filter clause (WHERE in non-scoring context)");
  }

  @Test
  public void testExplainCompoundPredicateProducesBoolQuery() throws IOException {
    String plan =
        explainVectorSearch(
            "[1.0, 2.0, 3.0]", "k=10", "WHERE v.state = 'TX' AND v.age > 30 LIMIT 10");

    assertBoolMustFilter(
        plan, "Explain should contain filter clause (compound WHERE in non-scoring context)");
  }

  @Test
  public void testExplainRadialWithWhereProducesBoolQuery() throws IOException {
    String plan =
        explainVectorSearch("[1.0, 2.0]", "max_distance=10.5", "WHERE v.state = 'TX' LIMIT 100");

    assertBoolMustFilter(
        plan, "Explain should contain filter clause (WHERE in non-scoring context)");
  }

  // ── Sort + LIMIT explain ─────────────────────────────────────────────

  @Test
  public void testOrderByScoreDescExplainSucceeds() throws IOException {
    String plan = explainVectorSearch("[1.0, 2.0]", "k=5", "ORDER BY v._score DESC LIMIT 5");

    assertExplainHas(plan, "Explain should succeed with ORDER BY _score DESC", "wrapper");
  }

  @Test
  public void testExplainLimitWithinKSucceeds() throws IOException {
    String plan = explainVectorSearch("[1.0, 2.0]", "k=10", "LIMIT 5");

    assertExplainHas(plan, "Explain should succeed with LIMIT <= k", "wrapper");
  }
}
171 changes: 171 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.sql;

import static org.hamcrest.Matchers.containsString;

import java.io.IOException;
import org.junit.Test;
import org.opensearch.client.ResponseException;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.TestsConstants;

/**
 * Integration tests for the validation and error paths of the vectorSearch SQL table function.
 * Each test submits an invalid query and checks that the resulting {@link ResponseException}
 * message contains the expected fragment. Explain-plan DSL shape tests live in {@link
 * VectorSearchExplainIT}.
 */
public class VectorSearchIT extends SQLIntegTestCase {

  private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT;

  @Override
  protected void init() throws Exception {
    loadIndex(Index.ACCOUNT);
  }

  // ── Shared helpers ────────────────────────────────────────────────────

  /**
   * Runs {@code sql}, requires it to fail with a {@link ResponseException}, and asserts that the
   * exception message contains {@code expectedFragment}.
   */
  private void assertRejectedWith(String sql, String expectedFragment) {
    ResponseException failure = expectThrows(ResponseException.class, () -> executeQuery(sql));
    assertThat(failure.getMessage(), containsString(expectedFragment));
  }

  /**
   * Assembles a query with the literal table name {@code t}, used by the argument-validation
   * tests.
   *
   * @param field value placed inside {@code field='...'}
   * @param vector value placed inside {@code vector='...'}
   * @param option value placed inside {@code option='...'}
   */
  private static String argumentQuery(String field, String vector, String option) {
    return "SELECT v._id FROM vectorSearch(table='t', field='"
        + field
        + "', vector='"
        + vector
        + "', option='"
        + option
        + "') AS v";
  }

  /**
   * Assembles a k=5 top-k query against {@link #TEST_INDEX} with the given ORDER BY clause and
   * {@code LIMIT 5}.
   *
   * @param orderByClause full ORDER BY clause including its trailing space
   */
  private static String sortedQuery(String orderByClause) {
    return "SELECT v._id FROM vectorSearch(table='"
        + TEST_INDEX
        + "', field='embedding', vector='[1.0, 2.0]', option='k=5') AS v "
        + orderByClause
        + "LIMIT 5";
  }

  // ── Validation error paths ────────────────────────────────────────────

  @Test
  public void testMutualExclusivityRejectsKAndMaxDistance() throws IOException {
    assertRejectedWith(argumentQuery("f", "[1.0]", "k=5,max_distance=10"), "Only one of");
  }

  @Test
  public void testMutualExclusivityRejectsKAndMinScore() throws IOException {
    assertRejectedWith(argumentQuery("f", "[1.0]", "k=5,min_score=0.5"), "Only one of");
  }

  @Test
  public void testKTooLargeRejects() throws IOException {
    assertRejectedWith(argumentQuery("f", "[1.0]", "k=10001"), "k must be between 1 and 10000");
  }

  @Test
  public void testKZeroRejects() throws IOException {
    assertRejectedWith(argumentQuery("f", "[1.0]", "k=0"), "k must be between 1 and 10000");
  }

  @Test
  public void testUnknownOptionKeyRejects() throws IOException {
    assertRejectedWith(
        argumentQuery("f", "[1.0]", "k=5,method.ef_search=100"), "Unknown option key");
  }

  @Test
  public void testEmptyVectorRejects() throws IOException {
    assertRejectedWith(argumentQuery("f", "[]", "k=5"), "must not be empty");
  }

  @Test
  public void testInvalidFieldNameRejects() throws IOException {
    // The field name carries a backslash followed by a double quote.
    assertRejectedWith(
        argumentQuery("field\\\"injection", "[1.0]", "k=5"), "Invalid field name");
  }

  @Test
  public void testMissingRequiredOptionRejects() throws IOException {
    assertRejectedWith(argumentQuery("f", "[1.0]", ""), "Missing required option");
  }

  // ── Sort restriction validation ─────────────────────────────────────────

  @Test
  public void testOrderByNonScoreFieldRejects() throws IOException {
    assertRejectedWith(sortedQuery("ORDER BY v.firstname ASC "), "unsupported sort expression");
  }

  @Test
  public void testOrderByScoreAscRejects() throws IOException {
    assertRejectedWith(sortedQuery("ORDER BY v._score ASC "), "_score ASC is not supported");
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,17 @@ public TableScanBuilder createScanBuilder() {

// Use VectorSearchQueryBuilder to keep knn in must (scoring) context.
// WHERE filters will be placed in filter (non-scoring) context.
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery());
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery(), options);
requestBuilder.pushDownTrackedScore(true);

// Top-k mode: default size to k so queries without LIMIT return k results
// instead of falling into the generic large-scan path.
// LIMIT pushdown will further reduce this if present.
// Default size policy: LIMIT pushdown will further reduce if present.
if (options.containsKey("k")) {
// Top-k mode: default size to k so queries without LIMIT return k results.
requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0);
} else {
// Radial mode (max_distance/min_score): cap at maxResultWindow.
// Without an explicit cap, radial queries could return unbounded results.
requestBuilder.pushDownLimitToRequestTotal(getMaxResultWindow(), 0);
}

Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator =
Expand All @@ -68,6 +71,11 @@ public TableScanBuilder createScanBuilder() {
}

/**
 * Wraps the JSON string produced by {@link #buildKnnQueryJson()} in a {@link WrapperQueryBuilder}.
 *
 * @return the wrapper query built from that JSON string
 */
private QueryBuilder buildKnnQuery() {
  String knnJson = buildKnnQueryJson();
  return new WrapperQueryBuilder(knnJson);
}

// Package-private for testing
String buildKnnQueryJson() {
StringBuilder vectorJson = new StringBuilder("[");
for (int i = 0; i < vector.length; i++) {
if (i > 0) vectorJson.append(",");
Expand All @@ -88,11 +96,9 @@ private QueryBuilder buildKnnQuery() {
}
}

String knnQueryJson =
String.format(
"{\"knn\":{\"%s\":{\"vector\":%s%s}}}",
field, vectorJson.toString(), optionsJson.toString());
return new WrapperQueryBuilder(knnQueryJson);
return String.format(
"{\"knn\":{\"%s\":{\"vector\":%s%s}}}",
field, vectorJson.toString(), optionsJson.toString());
}

private static boolean isNumeric(String str) {
Expand Down
Loading
Loading