From 27fa80ea904eb7fe338ef03690eb37edf4d74142 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 15:40:53 -0700 Subject: [PATCH 01/11] Add knn_vector as recognized MappingType in OpenSearchDataType MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Maps knn_vector fields to ExprCoreType.ARRAY so they appear in DESCRIBE output and can be referenced in projections. This is a visibility shim — not a full vector type. Signed-off-by: Eric Wei --- .../sql/opensearch/data/type/OpenSearchDataType.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java index 837a2a062ef..79d49a143de 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataType.java @@ -43,7 +43,8 @@ public enum MappingType { ScaledFloat("scaled_float", ExprCoreType.DOUBLE), Double("double", ExprCoreType.DOUBLE), Boolean("boolean", ExprCoreType.BOOLEAN), - Alias("alias", ExprCoreType.UNKNOWN); + Alias("alias", ExprCoreType.UNKNOWN), + KnnVector("knn_vector", ExprCoreType.ARRAY); // TODO: ranges, geo shape, point, shape private final String name; From f8b402ea65813b734c42f49f9edf4b1e8b9ce488 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 15:42:26 -0700 Subject: [PATCH 02/11] Widen OpenSearchIndexScanBuilder constructor to public VectorSearchIndex.createScanBuilder() needs to construct an OpenSearchIndexScanBuilder with a custom VectorSearchQueryBuilder delegate. The existing constructor was protected (test-only). Signed-off-by: Eric Wei --- .../opensearch/storage/scan/OpenSearchIndexScanBuilder.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java index 70e6f0f2157..af9d46cd745 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -45,8 +45,8 @@ public OpenSearchIndexScanBuilder( this.scanFactory = scanFactory; } - /** Constructor used for unit tests. */ - protected OpenSearchIndexScanBuilder( + /** Constructor that accepts a custom PushDownQueryBuilder delegate. */ + public OpenSearchIndexScanBuilder( PushDownQueryBuilder translator, Function scanFactory) { this.delegate = translator; From cd74478f4941d3cd82a6b092387e6cd20aa56a70 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 15:59:51 -0700 Subject: [PATCH 03/11] Add vector search table function, query builder, and index Introduces the core execution pipeline for vectorsearch(): - VectorSearchTableFunctionResolver: registers vectorsearch with 4 STRING args - VectorSearchTableFunctionImplementation: parses named args, vector literal, options string, validates search mode (k/max_distance/min_score) - VectorSearchIndex: extends OpenSearchIndex with knn query seeding, score tracking, and WrapperQueryBuilder DSL construction - VectorSearchQueryBuilder: keeps knn in must (scoring) context, WHERE filters in filter (non-scoring) context Signed-off-by: Eric Wei --- .../opensearch/storage/VectorSearchIndex.java | 100 ++++++++++++ ...ctorSearchTableFunctionImplementation.java | 131 +++++++++++++++ .../VectorSearchTableFunctionResolver.java | 61 +++++++ .../scan/VectorSearchQueryBuilder.java | 47 ++++++ ...SearchTableFunctionImplementationTest.java | 151 ++++++++++++++++++ ...VectorSearchTableFunctionResolverTest.java | 86 ++++++++++ 6 files changed, 576 insertions(+) create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java new file mode 100644 index 00000000000..6705bc67d05 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import java.util.Map; +import java.util.function.Function; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; +import org.opensearch.sql.opensearch.storage.scan.VectorSearchQueryBuilder; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Vector-search-aware OpenSearch index. Seeds the scan with a knn query and enables score tracking. + */ +public class VectorSearchIndex extends OpenSearchIndex { + + private static final String VECTOR_OPTION = "vector"; + + private final String field; + private final float[] vector; + private final Map options; + + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map options) { + super(client, settings, indexName); + this.field = field; + this.vector = vector; + this.options = options; + } + + @Override + public TableScanBuilder createScanBuilder() { + final TimeValue cursorKeepAlive = + getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); + var requestBuilder = createRequestBuilder(); + + // Use VectorSearchQueryBuilder to keep knn in must (scoring) context. + // WHERE filters will be placed in filter (non-scoring) context. + var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); + requestBuilder.pushDownTrackedScore(true); + + Function createScanOperator = + rb -> + new OpenSearchIndexScan( + getClient(), + rb.getMaxResponseSize(), + rb.build(getIndexName(), cursorKeepAlive, getClient(), getFieldTypes().isEmpty())); + return new OpenSearchIndexScanBuilder(queryBuilder, createScanOperator); + } + + private QueryBuilder buildKnnQuery() { + StringBuilder vectorJson = new StringBuilder("["); + for (int i = 0; i < vector.length; i++) { + if (i > 0) vectorJson.append(","); + vectorJson.append(vector[i]); + } + vectorJson.append("]"); + + StringBuilder optionsJson = new StringBuilder(); + for (Map.Entry entry : options.entrySet()) { + optionsJson.append(","); + String value = entry.getValue(); + // Numeric values go unquoted, everything else quoted + if (isNumeric(value)) { + optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value)); + } else { + optionsJson.append(String.format("\"%s\":\"%s\"", entry.getKey(), value)); + } + } + + String knnQueryJson = + String.format( + "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", + field, vectorJson.toString(), optionsJson.toString()); + return new WrapperQueryBuilder(knnQueryJson); + } + + private static boolean isNumeric(String str) { + try { + Double.parseDouble(str); + return true; + } catch (NumberFormatException e) { + return false; + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java new file mode 100644 index 00000000000..631fd194ce6 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.FIELD; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.OPTION; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE; +import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.storage.Table; + +public class VectorSearchTableFunctionImplementation extends FunctionExpression + implements TableFunctionImplementation { + + private final FunctionName functionName; + private final List arguments; + private final OpenSearchClient client; + private final Settings settings; + + public VectorSearchTableFunctionImplementation( + FunctionName functionName, + List arguments, + OpenSearchClient client, + Settings settings) { + super(functionName, arguments); + this.functionName = functionName; + this.arguments = arguments; + this.client = client; + this.settings = settings; + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException( + String.format("vectorSearch function [%s] is only supported in FROM clause", functionName)); + } + + @Override + public ExprType type() { + return ExprCoreType.STRUCT; + } + + @Override + public String toString() { + List args = + arguments.stream() + .map( + arg -> + String.format( + "%s=%s", + ((NamedArgumentExpression) arg).getArgName(), + ((NamedArgumentExpression) arg).getValue().toString())) + .collect(Collectors.toList()); + return String.format("%s(%s)", functionName, String.join(", ", args)); + } + + @Override + public Table applyArguments() { + String tableName = getArgumentValue(TABLE); + String fieldName = getArgumentValue(FIELD); + String vectorLiteral = getArgumentValue(VECTOR); + String optionStr = getArgumentValue(OPTION); + + float[] vector = parseVector(vectorLiteral); + Map options = parseOptions(optionStr); + validateOptions(options); + + return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options); + } + + private float[] parseVector(String vectorLiteral) { + String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim(); + String[] parts = cleaned.split(","); + float[] vector = new float[parts.length]; + for (int i = 0; i < parts.length; i++) { + vector[i] = Float.parseFloat(parts[i].trim()); + } + return vector; + } + + static Map parseOptions(String optionStr) { + Map options = new LinkedHashMap<>(); + for (String pair : optionStr.split(",")) { + String[] kv = pair.trim().split("=", 2); + if (kv.length == 2) { + options.put(kv[0].trim(), kv[1].trim()); + } + } + return options; + } + + private void validateOptions(Map options) { + boolean hasK = options.containsKey("k"); + boolean hasMaxDistance = options.containsKey("max_distance"); + boolean hasMinScore = options.containsKey("min_score"); + if (!hasK && !hasMaxDistance && !hasMinScore) { + throw new ExpressionEvaluationException( + "Missing required option: one of k, max_distance, or min_score"); + } + } + + private String getArgumentValue(String name) { + return arguments.stream() + .filter(arg -> ((NamedArgumentExpression) arg).getArgName().equalsIgnoreCase(name)) + .map(arg -> ((NamedArgumentExpression) arg).getValue().valueOf().stringValue()) + .findFirst() + .orElseThrow( + () -> + new ExpressionEvaluationException( + String.format("Missing required argument: %s", name))); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java new file mode 100644 index 00000000000..a23d5876fcd --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolver.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@RequiredArgsConstructor +public class VectorSearchTableFunctionResolver implements FunctionResolver { + + public static final String VECTOR_SEARCH = "vectorsearch"; + public static final String TABLE = "table"; + public static final String FIELD = "field"; + public static final String VECTOR = "vector"; + public static final String OPTION = "option"; + public static final List ARGUMENT_NAMES = List.of(TABLE, FIELD, VECTOR, OPTION); + + private final OpenSearchClient client; + private final Settings settings; + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + FunctionName functionName = FunctionName.of(VECTOR_SEARCH); + FunctionSignature functionSignature = + new FunctionSignature(functionName, List.of(STRING, STRING, STRING, STRING)); + FunctionBuilder functionBuilder = + (functionProperties, arguments) -> { + validateArguments(arguments); + return new VectorSearchTableFunctionImplementation( + functionName, arguments, client, settings); + }; + return Pair.of(functionSignature, functionBuilder); + } + + @Override + public FunctionName getFunctionName() { + return FunctionName.of(VECTOR_SEARCH); + } + + private void validateArguments(List arguments) { + if (arguments.size() != ARGUMENT_NAMES.size()) { + throw new IllegalArgumentException( + String.format( + "vectorSearch requires %d arguments (%s), got %d", + ARGUMENT_NAMES.size(), String.join(", ", ARGUMENT_NAMES), arguments.size())); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java new file mode 100644 index 00000000000..efc2f333b0d --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; +import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer; +import org.opensearch.sql.planner.logical.LogicalFilter; + +/** + * Query builder for vector search that keeps the knn query in a scoring (must) context and puts + * WHERE filters in a non-scoring (filter) context. This prevents the knn relevance scores from + * being destroyed when a WHERE clause is pushed down. + * + *

Without this, the default pushDownFilter wraps both queries into bool.filter, which is a + * non-scoring context. + */ +public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder { + + private final QueryBuilder knnQuery; + + public VectorSearchQueryBuilder(OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery) { + super(requestBuilder); + // Set knn as the initial query (scoring context) + requestBuilder.getSourceBuilder().query(knnQuery); + this.knnQuery = knnQuery; + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer()); + Expression queryCondition = filter.getCondition(); + QueryBuilder whereQuery = queryBuilder.build(queryCondition); + + // Combine: knn in must (scores), WHERE in filter (no scoring impact) + BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery); + requestBuilder.getSourceBuilder().query(combined); + return true; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java new file mode 100644 index 00000000000..326f4c42991 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -0,0 +1,151 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.client.OpenSearchClient; +import org.opensearch.sql.storage.Table; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionImplementationTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Test + void testValueOfThrows() { + VectorSearchTableFunctionImplementation impl = createImpl(); + UnsupportedOperationException ex = + assertThrows(UnsupportedOperationException.class, () -> impl.valueOf()); + assertTrue(ex.getMessage().contains("only supported in FROM clause")); + } + + @Test + void testType() { + VectorSearchTableFunctionImplementation impl = createImpl(); + assertEquals(ExprCoreType.STRUCT, impl.type()); + } + + @Test + void testToString() { + VectorSearchTableFunctionImplementation impl = createImpl(); + String str = impl.toString(); + assertTrue(str.contains("vectorsearch")); + assertTrue(str.contains("table=")); + assertTrue(str.contains("my-index")); + } + + @Test + void testApplyArguments() { + VectorSearchTableFunctionImplementation impl = createImpl(); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithBracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithUnbracketedVector() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "1.0, 2.0, 3.0", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithComplexOptions() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithMaxDistance() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=10.0"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testApplyArgumentsWithMinScore() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=0.5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testMissingSearchModeOptionThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "method.ef_search=100"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("one of k, max_distance, or min_score")); + } + + @Test + void testParseOptionsMultiple() { + Map opts = + VectorSearchTableFunctionImplementation.parseOptions("k=5,method.ef_search=100"); + assertEquals("5", opts.get("k")); + assertEquals("100", opts.get("method.ef_search")); + } + + @Test + void testMissingArgumentThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertEquals("Missing required argument: option", ex.getMessage()); + } + + private VectorSearchTableFunctionImplementation createImpl() { + return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); + } + + private VectorSearchTableFunctionImplementation createImplWithArgs( + String table, String field, String vector, String option) { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal(table)), + DSL.namedArgument("field", DSL.literal(field)), + DSL.namedArgument("vector", DSL.literal(vector)), + DSL.namedArgument("option", DSL.literal(option))); + return new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java new file mode 100644 index 00000000000..77efd0a6d88 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.FunctionBuilder; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionSignature; +import org.opensearch.sql.expression.function.TableFunctionImplementation; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@ExtendWith(MockitoExtension.class) +class VectorSearchTableFunctionResolverTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Mock private FunctionProperties functionProperties; + + @Test + void testResolve() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0, 3.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + + assertEquals(functionName, resolution.getKey().getFunctionName()); + assertEquals(functionName, resolver.getFunctionName()); + assertEquals(List.of(STRING, STRING, STRING, STRING), resolution.getKey().getParamTypeList()); + + TableFunctionImplementation impl = + (TableFunctionImplementation) resolution.getValue().apply(functionProperties, expressions); + assertTrue(impl instanceof VectorSearchTableFunctionImplementation); + } + + @Test + void testWrongArgumentCount() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } +} From e05fef7ca703d38f5fa6e669e7bbb6eb8f277d80 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 16:06:37 -0700 Subject: [PATCH 04/11] Register VectorSearchTableFunctionResolver in OpenSearchStorageEngine Override getFunctions() to expose vectorsearch() table function to the query analysis pipeline. Signed-off-by: Eric Wei --- .../opensearch/storage/OpenSearchStorageEngine.java | 8 ++++++++ .../storage/OpenSearchStorageEngineTest.java | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java index ce6740cd784..1b7de315fb6 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngine.java @@ -7,10 +7,13 @@ import static org.opensearch.sql.utils.SystemIndexUtils.isSystemIndex; +import java.util.Collection; +import java.util.List; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.StorageEngine; @@ -25,6 +28,11 @@ public class OpenSearchStorageEngine implements StorageEngine { @Getter private final Settings settings; + @Override + public Collection getFunctions() { + return List.of(new VectorSearchTableFunctionResolver(client, settings)); + } + @Override public Table getTable(DataSourceSchemaName dataSourceSchemaName, String name) { if (isSystemIndex(name)) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index 38f2ae495e0..0ed7ce31675 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -6,17 +6,20 @@ package org.opensearch.sql.opensearch.storage; import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO; +import java.util.Collection; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.DataSourceSchemaName; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.storage.system.OpenSearchSystemIndex; import org.opensearch.sql.storage.Table; @@ -36,6 +39,14 @@ public void getTable() { assertAll(() -> assertNotNull(table), () -> assertTrue(table instanceof OpenSearchIndex)); } + @Test + public void getFunctionsReturnsVectorSearchResolver() { + OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); + Collection functions = engine.getFunctions(); + assertEquals(1, functions.size()); + assertTrue(functions.iterator().next() instanceof VectorSearchTableFunctionResolver); + } + @Test public void getSystemTable() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); From dc7aeb62280102a7125a05c3f737d472f5ed2668 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 16:07:19 -0700 Subject: [PATCH 05/11] Add VectorSearchQueryBuilder unit tests Verifies knn query is placed in scoring (must) context, not wrapped in bool.filter when no WHERE clause is present. Signed-off-by: Eric Wei --- .../scan/VectorSearchQueryBuilderTest.java | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java new file mode 100644 index 00000000000..88d834e1810 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +import org.junit.jupiter.api.Test; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; + +class VectorSearchQueryBuilderTest { + + @Test + void knnQuerySetAsScoringQuery() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ=="); + + new VectorSearchQueryBuilder(requestBuilder, knnQuery); + + QueryBuilder query = requestBuilder.getSourceBuilder().query(); + assertTrue( + query instanceof WrapperQueryBuilder, + "knn query should be set directly as top-level query (scoring context)"); + } + + @Test + void knnQueryNotWrappedInFilterWhenNoWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ=="); + + new VectorSearchQueryBuilder(requestBuilder, knnQuery); + + QueryBuilder query = requestBuilder.getSourceBuilder().query(); + assertTrue( + query instanceof WrapperQueryBuilder, + "Without WHERE clause, knn query should NOT be wrapped in bool.filter"); + } + + private OpenSearchRequestBuilder createRequestBuilder() { + return new OpenSearchRequestBuilder( + mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class)); + } +} From 0d92dc2bc22431f00812862eaa5b1d18bf722cbb Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Tue, 7 Apr 2026 17:34:55 -0700 Subject: [PATCH 06/11] Address review feedback: add validation guards and pushDownFilter test - Add pushDownFilter() unit test asserting knn stays in bool.must (scoring) and WHERE predicate goes to bool.filter (non-scoring) - Add option key allowlist (k, max_distance, min_score) to reject unknown/unsupported keys before they reach DSL generation - Add field name validation to reject characters that could corrupt the WrapperQueryBuilder JSON (allows alphanumeric, dots, underscores, hyphens) - Add named-arg type guard to reject non-NamedArgumentExpression args early with a clear error message Signed-off-by: Eric Wei --- ...ctorSearchTableFunctionImplementation.java | 46 +++++++++++++++++++ ...SearchTableFunctionImplementationTest.java | 44 +++++++++++++++--- .../scan/VectorSearchQueryBuilderTest.java | 33 ++++++++++--- 3 files changed, 109 insertions(+), 14 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index 631fd194ce6..2870daf4a5b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -13,6 +13,8 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; import java.util.stream.Collectors; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; @@ -31,6 +33,15 @@ public class VectorSearchTableFunctionImplementation extends FunctionExpression implements TableFunctionImplementation { + /** P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. */ + static final Set ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score"); + + /** + * Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores, + * hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. + */ + private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$"); + private final FunctionName functionName; private final List arguments; private final OpenSearchClient client; @@ -75,8 +86,10 @@ public String toString() { @Override public Table applyArguments() { + validateNamedArgs(); String tableName = getArgumentValue(TABLE); String fieldName = getArgumentValue(FIELD); + validateFieldName(fieldName); String vectorLiteral = getArgumentValue(VECTOR); String optionStr = getArgumentValue(OPTION); @@ -108,7 +121,40 @@ static Map parseOptions(String optionStr) { return options; } + /** Reject non-named arguments early. vectorSearch() requires named args (key=value). */ + private void validateNamedArgs() { + for (Expression arg : arguments) { + if (!(arg instanceof NamedArgumentExpression)) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received: " + + arg.getClass().getSimpleName()); + } + } + } + + /** + * Reject field names with characters that could corrupt the WrapperQueryBuilder JSON. Allows + * alphanumeric, dots (nested fields), underscores, and hyphens. + */ + private void validateFieldName(String fieldName) { + if (!SAFE_FIELD_NAME.matcher(fieldName).matches()) { + throw new ExpressionEvaluationException( + String.format( + "Invalid field name '%s': must contain only alphanumeric characters," + + " dots, underscores, or hyphens", + fieldName)); + } + } + private void validateOptions(Map options) { + // Reject unknown option keys — only P0 keys are allowed + for (String key : options.keySet()) { + if (!ALLOWED_OPTION_KEYS.contains(key)) { + throw new ExpressionEvaluationException( + String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS)); + } + } boolean hasK = options.containsKey("k"); boolean hasMaxDistance = options.containsKey("max_distance"); boolean hasMinScore = options.containsKey("min_score"); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index 326f4c42991..f55e14955bc 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -78,11 +78,13 @@ void testApplyArgumentsWithUnbracketedVector() { } @Test - void testApplyArgumentsWithComplexOptions() { + void testUnknownOptionKeyThrows() { VectorSearchTableFunctionImplementation impl = createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100"); - Table table = impl.applyArguments(); - assertTrue(table instanceof VectorSearchIndex); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Unknown option key")); + assertTrue(ex.getMessage().contains("method.ef_search")); } @Test @@ -104,18 +106,18 @@ void testApplyArgumentsWithMinScore() { @Test void testMissingSearchModeOptionThrows() { VectorSearchTableFunctionImplementation impl = - createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "method.ef_search=100"); + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "not_a_key=100"); ExpressionEvaluationException ex = assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); - assertTrue(ex.getMessage().contains("one of k, max_distance, or min_score")); + assertTrue(ex.getMessage().contains("Unknown option key")); } @Test void testParseOptionsMultiple() { Map opts = - VectorSearchTableFunctionImplementation.parseOptions("k=5,method.ef_search=100"); + VectorSearchTableFunctionImplementation.parseOptions("k=5,max_distance=10.0"); assertEquals("5", opts.get("k")); - assertEquals("100", opts.get("method.ef_search")); + assertEquals("10.0", opts.get("max_distance")); } @Test @@ -133,6 +135,34 @@ void testMissingArgumentThrows() { assertEquals("Missing required argument: option", ex.getMessage()); } + @Test + void testInvalidFieldNameThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "field\"injection", "[1.0, 2.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid field name")); + } + + @Test + void testNestedFieldNameAllowed() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "doc.embedding", "[1.0, 2.0]", "k=5"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testNonNamedArgThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = List.of(DSL.literal("my-index")); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + private VectorSearchTableFunctionImplementation createImpl() { return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java index 88d834e1810..2df785b41a0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -5,22 +5,30 @@ package org.opensearch.sql.opensearch.storage.scan; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import java.util.Collections; import org.junit.jupiter.api.Test; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.WrapperQueryBuilder; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalValues; class VectorSearchQueryBuilderTest { @Test void knnQuerySetAsScoringQuery() { var requestBuilder = createRequestBuilder(); - var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ=="); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); new VectorSearchQueryBuilder(requestBuilder, knnQuery); @@ -31,16 +39,27 @@ void knnQuerySetAsScoringQuery() { } @Test - void knnQueryNotWrappedInFilterWhenNoWhere() { + void pushDownFilterKeepsKnnInScoringContext() { var requestBuilder = createRequestBuilder(); - var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ=="); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery); - new VectorSearchQueryBuilder(requestBuilder, knnQuery); + // Simulate WHERE name = 'John' + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); - QueryBuilder query = requestBuilder.getSourceBuilder().query(); + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "WHERE predicate should be in filter (non-scoring)"); assertTrue( - query instanceof WrapperQueryBuilder, - "Without WHERE clause, knn query should NOT be wrapped in bool.filter"); + boolQuery.must().get(0) instanceof WrapperQueryBuilder, + "must clause should contain the original knn WrapperQueryBuilder"); } private OpenSearchRequestBuilder createRequestBuilder() { From 1ec9aa1984aadf42d7de5a688a08e003f209877d Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Wed, 8 Apr 2026 10:31:40 -0700 Subject: [PATCH 07/11] Canonicalize option values as numeric types before DSL generation Parse k as integer, max_distance and min_score as double before they reach buildKnnQuery(). Rejects non-numeric and non-finite values with clear errors. This closes the residual JSON-injection path through option values without requiring full XContent migration. Also fixes toString() to be consistent with the named-arg guard (no longer blindly casts to NamedArgumentExpression). Signed-off-by: Eric Wei --- ...ctorSearchTableFunctionImplementation.java | 50 +++++++++++++++++-- ...SearchTableFunctionImplementationTest.java | 27 ++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index 2870daf4a5b..25f0a9ea9ab 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -75,11 +75,13 @@ public String toString() { List args = arguments.stream() .map( - arg -> - String.format( - "%s=%s", - ((NamedArgumentExpression) arg).getArgName(), - ((NamedArgumentExpression) arg).getValue().toString())) + arg -> { + if (arg instanceof NamedArgumentExpression) { + NamedArgumentExpression named = (NamedArgumentExpression) arg; + return String.format("%s=%s", named.getArgName(), named.getValue().toString()); + } + return arg.toString(); + }) .collect(Collectors.toList()); return String.format("%s(%s)", functionName, String.join(", ", args)); } @@ -147,6 +149,10 @@ private void validateFieldName(String fieldName) { } } + /** + * Validates and canonicalizes option values. All P0 option values must be numeric. Parsing them + * here prevents non-numeric strings from reaching the raw JSON construction in buildKnnQuery(). + */ private void validateOptions(Map options) { // Reject unknown option keys — only P0 keys are allowed for (String key : options.keySet()) { @@ -162,6 +168,40 @@ private void validateOptions(Map options) { throw new ExpressionEvaluationException( "Missing required option: one of k, max_distance, or min_score"); } + // Parse and canonicalize numeric values — closes JSON injection via option values + if (hasK) { + parseIntOption(options, "k"); + } + if (hasMaxDistance) { + parseDoubleOption(options, "max_distance"); + } + if (hasMinScore) { + parseDoubleOption(options, "min_score"); + } + } + + private void parseIntOption(Map options, String key) { + try { + int value = Integer.parseInt(options.get(key)); + options.put(key, Integer.toString(value)); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be an integer, got '%s'", key, options.get(key))); + } + } + + private void parseDoubleOption(Map options, String key) { + try { + double value = Double.parseDouble(options.get(key)); + if (!Double.isFinite(value)) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a finite number, got '%s'", key, options.get(key))); + } + options.put(key, Double.toString(value)); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Option '%s' must be a number, got '%s'", key, options.get(key))); + } } private String getArgumentValue(String name) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index f55e14955bc..b3bf7f24b9c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -152,6 +152,33 @@ void testNestedFieldNameAllowed() { assertTrue(table instanceof VectorSearchIndex); } + @Test + void testNonNumericKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=abc"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be an integer")); + } + + @Test + void testNonNumericMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=notanumber"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a number")); + } + + @Test + void testInfiniteMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=Infinity"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + @Test void testNonNamedArgThrows() { FunctionName functionName = FunctionName.of("vectorsearch"); From cbce86e7200143c33a07f9df407232af88a6cf9e Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Wed, 8 Apr 2026 11:16:16 -0700 Subject: [PATCH 08/11] Harden input validation and add size=k default for top-k mode - parseOptions: reject malformed segments and duplicate keys - parseVector: wrap errors in ExpressionEvaluationException, reject non-finite floats (Infinity, NaN) - VectorSearchIndex: default requestedTotalSize to k via pushDownLimitToRequestTotal so queries without LIMIT return k results - Add 5 new tests: malformed option, duplicate key, empty vector, malformed vector component, non-finite vector component Signed-off-by: Eric Wei --- .../opensearch/storage/VectorSearchIndex.java | 7 +++ ...ctorSearchTableFunctionImplementation.java | 31 +++++++++++-- ...SearchTableFunctionImplementationTest.java | 45 +++++++++++++++++++ 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java index 6705bc67d05..26e94724a84 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -53,6 +53,13 @@ public TableScanBuilder createScanBuilder() { var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); requestBuilder.pushDownTrackedScore(true); + // Top-k mode: default size to k so queries without LIMIT return k results + // instead of falling into the generic large-scan path. + // LIMIT pushdown will further reduce this if present. + if (options.containsKey("k")) { + requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); + } + Function createScanOperator = rb -> new OpenSearchIndexScan( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index 25f0a9ea9ab..a69bb890d57 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -104,10 +104,23 @@ public Table applyArguments() { private float[] parseVector(String vectorLiteral) { String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim(); + if (cleaned.isEmpty()) { + throw new ExpressionEvaluationException("Vector literal must not be empty"); + } String[] parts = cleaned.split(","); float[] vector = new float[parts.length]; for (int i = 0; i < parts.length; i++) { - vector[i] = Float.parseFloat(parts[i].trim()); + try { + vector[i] = Float.parseFloat(parts[i].trim()); + } catch (NumberFormatException e) { + throw new ExpressionEvaluationException( + String.format("Invalid vector component '%s': must be a number", parts[i].trim())); + } + if (!Float.isFinite(vector[i])) { + throw new ExpressionEvaluationException( + String.format( + "Invalid vector component '%s': must be a finite number", parts[i].trim())); + } } return vector; } @@ -115,10 +128,20 @@ private float[] parseVector(String vectorLiteral) { static Map parseOptions(String optionStr) { Map options = new LinkedHashMap<>(); for (String pair : optionStr.split(",")) { - String[] kv = pair.trim().split("=", 2); - if (kv.length == 2) { - options.put(kv[0].trim(), kv[1].trim()); + String trimmed = pair.trim(); + if (trimmed.isEmpty()) { + continue; + } + String[] kv = trimmed.split("=", 2); + if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) { + throw new ExpressionEvaluationException( + String.format("Malformed option segment '%s': expected key=value", trimmed)); + } + String key = kv[0].trim(); + if (options.containsKey(key)) { + throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key)); } + options.put(key, kv[1].trim()); } return options; } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index b3bf7f24b9c..4ca4c997332 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -120,6 +120,51 @@ void testParseOptionsMultiple() { assertEquals("10.0", opts.get("max_distance")); } + @Test + void testMalformedOptionSegmentThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,badoption")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testDuplicateOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("k=5,k=10")); + assertTrue(ex.getMessage().contains("Duplicate option key")); + } + + @Test + void testEmptyVectorThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must not be empty")); + } + + @Test + void testMalformedVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, abc, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Invalid vector component")); + } + + @Test + void testNonFiniteVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, Infinity, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + @Test void testMissingArgumentThrows() { FunctionName functionName = FunctionName.of("vectorsearch"); From 74c74ac79c271fac90b0cfea36242e8e1435731f Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Wed, 8 Apr 2026 11:53:55 -0700 Subject: [PATCH 09/11] Add null-arg-name guard and make storage engine test less brittle - validateNamedArgs() now rejects null/empty arg names defensively, closing a potential NPE if the shared table-function path is later wired into PPL - OpenSearchStorageEngineTest uses contains-check instead of exact collection size assertion - Add testNullArgNameThrows test Signed-off-by: Eric Wei --- .../VectorSearchTableFunctionImplementation.java | 8 +++++++- .../storage/OpenSearchStorageEngineTest.java | 6 +++--- ...torSearchTableFunctionImplementationTest.java | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index a69bb890d57..c4c383f9623 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -146,7 +146,7 @@ static Map parseOptions(String optionStr) { return options; } - /** Reject non-named arguments early. vectorSearch() requires named args (key=value). */ + /** Reject non-named arguments and null arg names early. */ private void validateNamedArgs() { for (Expression arg : arguments) { if (!(arg instanceof NamedArgumentExpression)) { @@ -155,6 +155,12 @@ private void validateNamedArgs() { + "but received: " + arg.getClass().getSimpleName()); } + String name = ((NamedArgumentExpression) arg).getArgName(); + if (name == null || name.isEmpty()) { + throw new ExpressionEvaluationException( + "vectorSearch() requires named arguments (e.g., table='index'), " + + "but received an argument with no name"); + } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java index 0ed7ce31675..fa04395e065 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchStorageEngineTest.java @@ -6,7 +6,6 @@ package org.opensearch.sql.opensearch.storage; import static org.junit.jupiter.api.Assertions.assertAll; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.analysis.DataSourceSchemaIdentifierNameResolver.DEFAULT_DATASOURCE_NAME; @@ -43,8 +42,9 @@ public void getTable() { public void getFunctionsReturnsVectorSearchResolver() { OpenSearchStorageEngine engine = new OpenSearchStorageEngine(client, settings); Collection functions = engine.getFunctions(); - assertEquals(1, functions.size()); - assertTrue(functions.iterator().next() instanceof VectorSearchTableFunctionResolver); + assertTrue( + functions.stream().anyMatch(f -> f instanceof VectorSearchTableFunctionResolver), + "getFunctions() should contain a VectorSearchTableFunctionResolver"); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index 4ca4c997332..19a9da5f32e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -235,6 +235,22 @@ void testNonNamedArgThrows() { assertTrue(ex.getMessage().contains("requires named arguments")); } + @Test + void testNullArgNameThrows() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument(null, DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("requires named arguments")); + } + private VectorSearchTableFunctionImplementation createImpl() { return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); } From c82bc5f95783c405f03d5141f55f494c6a8d80d0 Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Wed, 8 Apr 2026 12:01:57 -0700 Subject: [PATCH 10/11] Clean up dead code, fix misleading comment and test name - Remove unused VECTOR_OPTION constant from VectorSearchIndex - Clarify buildKnnQuery() comment: quoted fallback is for forward compatibility, all P0 values are already canonicalized as numeric - Rename testMissingSearchModeOptionThrows to testUnknownOptionKeyOnlyThrows to match what it actually tests Signed-off-by: Eric Wei --- .../opensearch/sql/opensearch/storage/VectorSearchIndex.java | 5 ++--- .../storage/VectorSearchTableFunctionImplementationTest.java | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java index 26e94724a84..f33d5f2fa73 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -23,8 +23,6 @@ */ public class VectorSearchIndex extends OpenSearchIndex { - private static final String VECTOR_OPTION = "vector"; - private final String field; private final float[] vector; private final Map options; @@ -81,7 +79,8 @@ private QueryBuilder buildKnnQuery() { for (Map.Entry entry : options.entrySet()) { optionsJson.append(","); String value = entry.getValue(); - // Numeric values go unquoted, everything else quoted + // All P0 option values are canonicalized to numeric strings by validateOptions(). + // The quoted fallback is retained for forward compatibility with future non-numeric options. if (isNumeric(value)) { optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value)); } else { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index 19a9da5f32e..c703d6639bc 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -104,7 +104,7 @@ void testApplyArgumentsWithMinScore() { } @Test - void testMissingSearchModeOptionThrows() { + void testUnknownOptionKeyOnlyThrows() { VectorSearchTableFunctionImplementation impl = createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "not_a_key=100"); ExpressionEvaluationException ex = From a33e68c61b23464cdbf7b4fd1e0db87e58b5f8fe Mon Sep 17 00:00:00 2001 From: Eric Wei Date: Wed, 8 Apr 2026 14:41:49 -0700 Subject: [PATCH 11/11] Add test for missing required option validation path Signed-off-by: Eric Wei --- .../VectorSearchTableFunctionImplementationTest.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index c703d6639bc..71f0bfa80af 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -138,6 +138,15 @@ void testDuplicateOptionKeyThrows() { assertTrue(ex.getMessage().contains("Duplicate option key")); } + @Test + void testNoRequiredOptionThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", ""); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Missing required option")); + } + @Test void testEmptyVectorThrows() { VectorSearchTableFunctionImplementation impl =