diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java index 837a2a062ef..79d49a143de 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java @@ -43,7 +43,8 @@ public enum MappingType { ScaledFloat("scaled_float", ExprCoreType.DOUBLE), Double("double", ExprCoreType.DOUBLE), Boolean("boolean", ExprCoreType.BOOLEAN), - Alias("alias", ExprCoreType.UNKNOWN); + Alias("alias", ExprCoreType.UNKNOWN), + KnnVector("knn_vector", ExprCoreType.ARRAY); // TODO: ranges, geo shape, point, shape private final String name; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index ce6740cd784..1b7de315fb6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -7,10 +7,13 @@ import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import java.util.Collection; +import java.util.List; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.StorageEngine; @@ -25,6 +28,11 @@ public class OpenSearchStorageEngine implements StorageEngine { @Getter private final Settings settings; + @Override + public Collection getFunctions() { + return List.of(new VectorSearchTableFunctionResolver(client, settings)); + } + @Override public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) { if (isSystemIndex(name)) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java new file mode 100644 index 00000000000..f33d5f2fa73 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -0,0 +1,106 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import java.util.Map; +import java.util.function.Function; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; +import org.opensearch.sql.opensearch.storage.scan.VectorSearchQueryBuilder; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Vector-search-aware OpenSearch index. Seeds the scan with a knn query and enables score tracking. + */ +public class VectorSearchIndex extends OpenSearchIndex { + + private final String field; + private final float[] vector; + private final Map options; + + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map options) { + super(client, settings, indexName); + this.field = field; + this.vector = vector; + this.options = options; + } + + @Override + public TableScanBuilder createScanBuilder() { + final TimeValue cursorKeepAlive = + getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + var requestBuilder = createRequestBuilder(); + + // Use VectorSearchQueryBuilder to keep knn in must (scoring) context. + // WHERE filters will be placed in filter (non-scoring) context. + var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); + requestBuilder.pushDownTrackedScore(true); + + // Top-k mode: default size to k so queries without LIMIT return k results + // instead of falling into the generic large-scan path. + // LIMIT pushdown will further reduce this if present. + if (options.containsKey("k")) { + requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); + } + + Function createScanOperator = + rb -> + new OpenSearchIndexScan( + getClient(), + rb.getMaxResponseSize(), + rb.build(getIndexName(), cursorKeepAlive, getClient(), getFieldTypes().isEmpty())); + return new OpenSearchIndexScanBuilder(queryBuilder, createScanOperator); + } + + private QueryBuilder buildKnnQuery() { + StringBuilder vectorJson = new StringBuilder("["); + for (int i = 0; i < vector.length; i++) { + if (i > 0) vectorJson.append(","); + vectorJson.append(vector[i]); + } + vectorJson.append("]"); + + StringBuilder optionsJson = new StringBuilder(); + for (Map.Entry entry : options.entrySet()) { + optionsJson.append(","); + String value = entry.getValue(); + // All P0 option values are canonicalized to numeric strings by validateOptions(). + // The quoted fallback is retained for forward compatibility with future non-numeric options. + if (isNumeric(value)) { + optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value)); + } else { + optionsJson.append(String.format("\"%s\":\"%s\"", entry.getKey(), value)); + } + } + + String knnQueryJson = + String.format( + "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", + field, vectorJson.toString(), optionsJson.toString()); + return new WrapperQueryBuilder(knnQueryJson); + } + + private static boolean isNumeric(String str) { + try { + Double.parseDouble(str); + return true; + } catch (NumberFormatException e) { + return false; + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java new file mode 100644 index 00000000000..c4c383f9623 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -0,0 +1,246 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.FIELD; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.OPTION; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.storage.Table; + +public class VectorSearchTableFunctionImplementation extends FunctionExpression + implements TableFunctionImplementation { + + /** P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. */ + static final Set ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score"); + + /** + * Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores, + * hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. + */ + private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$"); + + private final FunctionName functionName; + private final List arguments; + private final OpenSearchClient client; + private final Settings settings; + + public VectorSearchTableFunctionImplementation( + FunctionName functionName, + List arguments, + OpenSearchClient client, + Settings settings) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.client = client; + this.settings = settings; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format("vectorSearch function [%s] is only supported in FROM clause", functionName)); + } + + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } + + @Override + public String toString() { + List args = + arguments.stream() + .map( + arg -> { + if (arg instanceof NamedArgumentExpression) { + NamedArgumentExpression named = (NamedArgumentExpression) arg; + return String.format("%s=%s", named.getArgName(), named.getValue().toString()); + } + return arg.toString(); + }) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } + + @Override + public Table applyArguments() { + validateNamedArgs(); + String tableName = getArgumentValue(TABLE); + String fieldName = getArgumentValue(FIELD); + validateFieldName(fieldName); + String vectorLiteral = getArgumentValue(VECTOR); + String optionStr = getArgumentValue(OPTION); + + float[] vector = parseVector(vectorLiteral); + Map options = parseOptions(optionStr); + validateOptions(options); + + return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options); + } + + private float[] parseVector(String vectorLiteral) { + String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim(); + if (cleaned.isEmpty()) { + throw new ExpressionEvaluationException("Vector literal must not be empty"); + } + String[] parts = cleaned.split(","); + float[] vector = new float[parts.length]; + for (int i = 0; i < parts.length; i++) { + try { + vector[i] = Float.parseFloat(parts[i].trim()); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Invalid vector component '%s': must be a number", parts[i].trim())); + } + if (!Float.isFinite(vector[i])) { + throw new ExpressionEvaluationException( + String.format( + "Invalid vector component '%s': must be a finite number", parts[i].trim())); + } + } + return vector; + } + + static Map parseOptions(String optionStr) { + Map options = new LinkedHashMap<>(); + for (String pair : optionStr.split(",")) { + String trimmed = pair.trim(); + if (trimmed.isEmpty()) { + continue; + } + String[] kv = trimmed.split("=", 2); + if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) { + throw new ExpressionEvaluationException( + String.format("Malformed option segment '%s': expected key=value", trimmed)); + } + String key = kv[0].trim(); + if (options.containsKey(key)) { + throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key)); + } + options.put(key, kv[1].trim()); + } + return options; + } + + /** Reject non-named arguments and null arg names early. */ + private void validateNamedArgs() { + for (Expression arg : arguments) { + if (!(arg instanceof NamedArgumentExpression)) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received: " + + arg.getClass().getSimpleName()); + } + String name = ((NamedArgumentExpression) arg).getArgName(); + if (name == null || name.isEmpty()) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received an argument with no name"); + } + } + } + + /** + * Reject field names with characters that could corrupt the WrapperQueryBuilder JSON. Allows + * alphanumeric, dots (nested fields), underscores, and hyphens. + */ + private void validateFieldName(String fieldName) { + if (!SAFE_FIELD_NAME.matcher(fieldName).matches()) { + throw new ExpressionEvaluationException( + String.format( + "Invalid field name '%s': must contain only alphanumeric characters," + + " dots, underscores, or hyphens", + fieldName)); + } + } + + /** + * Validates and canonicalizes option values. All P0 option values must be numeric. Parsing them + * here prevents non-numeric strings from reaching the raw JSON construction in buildKnnQuery(). + */ + private void validateOptions(Map options) { + // Reject unknown option keys — only P0 keys are allowed + for (String key : options.keySet()) { + if (!ALLOWED_OPTION_KEYS.contains(key)) { + throw new ExpressionEvaluationException( + String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS)); + } + } + boolean hasK = options.containsKey("k"); + boolean hasMaxDistance = options.containsKey("max_distance"); + boolean hasMinScore = options.containsKey("min_score"); + if (!hasK && !hasMaxDistance && !hasMinScore) { + throw new ExpressionEvaluationException( + "Missing required option: one of k, max_distance, or min_score"); + } + // Parse and canonicalize numeric values — closes JSON injection via option values + if (hasK) { + parseIntOption(options, "k"); + } + if (hasMaxDistance) { + parseDoubleOption(options, "max_distance"); + } + if (hasMinScore) { + parseDoubleOption(options, "min_score"); + } + } + + private void parseIntOption(Map options, String key) { + try { + int value = Integer.parseInt(options.get(key)); + options.put(key, Integer.toString(value)); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be an integer, got '%s'", key, options.get(key))); + } + } + + private void parseDoubleOption(Map options, String key) { + try { + double value = Double.parseDouble(options.get(key)); + if (!Double.isFinite(value)) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a finite number, got '%s'", key, options.get(key))); + } + options.put(key, Double.toString(value)); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a number, got '%s'", key, options.get(key))); + } + } + + private String getArgumentValue(String name) { + return arguments.stream() + .filter(arg -> ((NamedArgumentExpression) arg).getArgName().equalsIgnoreCase(name)) + .map(arg -> ((NamedArgumentExpression) arg).getValue().valueOf().stringValue()) + .findFirst() + .orElseThrow( + () -> + new ExpressionEvaluationException( + String.format("Missing required argument: %s", name))); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java new file mode 100644 index 00000000000..a23d5876fcd --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@RequiredArgsConstructor +public class VectorSearchTableFunctionResolver implements FunctionResolver { + + public static final String VECTOR_SEARCH = "vectorsearch"; + public static final String TABLE = "table"; + public static final String FIELD = "field"; + public static final String VECTOR = "vector"; + public static final String OPTION = "option"; + public static final List ARGUMENT_NAMES = List.of(TABLE, FIELD, VECTOR, OPTION); + + private final OpenSearchClient client; + private final Settings settings; + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + FunctionName functionName = FunctionName.of(VECTOR_SEARCH); + FunctionSignature functionSignature = + new FunctionSignature(functionName, List.of(STRING, STRING, STRING, STRING)); + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> { + validateArguments(arguments); + return new VectorSearchTableFunctionImplementation( + functionName, arguments, client, settings); + }; + return Pair.of(functionSignature, functionBuilder); + } + + @Override + public FunctionName getFunctionName() { + return FunctionName.of(VECTOR_SEARCH); + } + + private void validateArguments(List arguments) { + if (arguments.size() != ARGUMENT_NAMES.size()) { + throw new IllegalArgumentException( + String.format( + "vectorSearch requires %d arguments (%s), got %d", + ARGUMENT_NAMES.size(), String.join(", ", ARGUMENT_NAMES), arguments.size())); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 70e6f0f2157..af9d46cd745 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -45,8 +45,8 @@ public OpenSearchIndexScanBuilder( this.scanFactory = scanFactory; } - /** Constructor used for unit tests. */ - protected OpenSearchIndexScanBuilder( + /** Constructor that accepts a custom PushDownQueryBuilder delegate. */ + public OpenSearchIndexScanBuilder( PushDownQueryBuilder translator, Function scanFactory) { this.delegate = translator; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java new file mode 100644 index 00000000000..efc2f333b0d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; +import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer; +import org.opensearch.sql.planner.logical.LogicalFilter; + +/** + * Query builder for vector search that keeps the knn query in a scoring (must) context and puts + * WHERE filters in a non-scoring (filter) context. This prevents the knn relevance scores from + * being destroyed when a WHERE clause is pushed down. + * + *

Without this, the default pushDownFilter wraps both queries into bool.filter, which is a + * non-scoring context. + */ +public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder { + + private final QueryBuilder knnQuery; + + public VectorSearchQueryBuilder(OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery) { + super(requestBuilder); + // Set knn as the initial query (scoring context) + requestBuilder.getSourceBuilder().query(knnQuery); + this.knnQuery = knnQuery; + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer()); + Expression queryCondition = filter.getCondition(); + QueryBuilder whereQuery = queryBuilder.build(queryCondition); + + // Combine: knn in must (scores), WHERE in filter (no scoring impact) + BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery); + requestBuilder.getSourceBuilder().query(combined); + return true; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index 38f2ae495e0..fa04395e065 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -11,12 +11,14 @@ import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO; +import java.util.Collection; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.Table; @@ -36,6 +38,15 @@ public void getTable() { assertAll(() -> assertNotNull(table), () -> assertTrue(table instanceof OpenSearchIndex)); } + @Test + public void getFunctionsReturnsVectorSearchResolver() { + OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); + Collection functions = engine.getFunctions(); + assertTrue( + functions.stream().anyMatch(f -> f instanceof VectorSearchTableFunctionResolver), + "getFunctions() should contain a VectorSearchTableFunctionResolver"); + } + @Test public void getSystemTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java new file mode 100644 index 00000000000..71f0bfa80af --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -0,0 +1,278 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.storage.Table; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionImplementationTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Test + void testValueOfThrows() { + VectorSearchTableFunctionImplementation impl = createImpl(); + UnsupportedOperationException ex = + assertThrows(UnsupportedOperationException.class, () -> impl.valueOf()); + assertTrue(ex.getMessage().contains("only supported in FROM clause")); + } + + @Test + void testType() { + VectorSearchTableFunctionImplementation impl = createImpl(); + assertEquals(ExprCoreType.STRUCT, impl.type()); + } + + @Test + void testToString() { + VectorSearchTableFunctionImplementation impl = createImpl(); + String str = impl.toString(); + assertTrue(str.contains("vectorsearch")); + assertTrue(str.contains("table=")); + assertTrue(str.contains("my-index")); + } + + @Test + void testApplyArguments() { + VectorSearchTableFunctionImplementation impl = createImpl(); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithBracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithUnbracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "1.0, 2.0, 3.0", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testUnknownOptionKeyThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + assertTrue(ex.getMessage().contains("method.ef_search")); + } + + @Test + void testApplyArgumentsWithMaxDistance() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=10.0"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithMinScore() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=0.5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testUnknownOptionKeyOnlyThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "not_a_key=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + } + + @Test + void testParseOptionsMultiple() { + Map opts = + VectorSearchTableFunctionImplementation.parseOptions("k=5,max_distance=10.0"); + assertEquals("5", opts.get("k")); + assertEquals("10.0", opts.get("max_distance")); + } + + @Test + void testMalformedOptionSegmentThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,badoption")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testDuplicateOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,k=10")); + assertTrue(ex.getMessage().contains("Duplicate option key")); + } + + @Test + void testNoRequiredOptionThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", ""); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Missing required option")); + } + + @Test + void testEmptyVectorThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must not be empty")); + } + + @Test + void testMalformedVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, abc, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + } + + @Test + void testNonFiniteVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, Infinity, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testMissingArgumentThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertEquals("Missing required argument: option", ex.getMessage()); + } + + @Test + void testInvalidFieldNameThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "field\"injection", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid field name")); + } + + @Test + void testNestedFieldNameAllowed() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "doc.embedding", "[1.0, 2.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testNonNumericKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=abc"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be an integer")); + } + + @Test + void testNonNumericMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=notanumber"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a number")); + } + + @Test + void testInfiniteMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=Infinity"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testNonNamedArgThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = List.of(DSL.literal("my-index")); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + + @Test + void testNullArgNameThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument(null, DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + + private VectorSearchTableFunctionImplementation createImpl() { + return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + } + + private VectorSearchTableFunctionImplementation createImplWithArgs( + String table, String field, String vector, String option) { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal(table)), + DSL.namedArgument("field", DSL.literal(field)), + DSL.namedArgument("vector", DSL.literal(vector)), + DSL.namedArgument("option", DSL.literal(option))); + return new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java new file mode 100644 index 00000000000..77efd0a6d88 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionResolverTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Mock private FunctionProperties functionProperties; + + @Test + void testResolve() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0, 3.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, resolver.getFunctionName()); + assertEquals(List.of(STRING, STRING, STRING, STRING), resolution.getKey().getParamTypeList()); + + TableFunctionImplementation impl = + (TableFunctionImplementation) resolution.getValue().apply(functionProperties, expressions); + assertTrue(impl instanceof VectorSearchTableFunctionImplementation); + } + + @Test + void testWrongArgumentCount() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java new file mode 100644 index 00000000000..2df785b41a0 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.Collections; +import org.junit.jupiter.api.Test; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalValues; + +class VectorSearchQueryBuilderTest { + + @Test + void knnQuerySetAsScoringQuery() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + + new VectorSearchQueryBuilder(requestBuilder, knnQuery); + + QueryBuilder query = requestBuilder.getSourceBuilder().query(); + assertTrue( + query instanceof WrapperQueryBuilder, + "knn query should be set directly as top-level query (scoring context)"); + } + + @Test + void pushDownFilterKeepsKnnInScoringContext() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery); + + // Simulate WHERE name = 'John' + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "WHERE predicate should be in filter (non-scoring)"); + assertTrue( + boolQuery.must().get(0) instanceof WrapperQueryBuilder, + "must clause should contain the original knn WrapperQueryBuilder"); + } + + private OpenSearchRequestBuilder createRequestBuilder() { + return new OpenSearchRequestBuilder( + mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class)); + } +}