-
Notifications
You must be signed in to change notification settings - Fork 211
[Feature] Add vector search execution pipeline for vectorSearch() table function #5320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
27fa80e
f8b402e
cd74478
e05fef7
dc7aeb6
0d92dc2
1ec9aa1
cbce86e
74c74ac
c82bc5f
a33e68c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.opensearch.storage; | ||
|
|
||
| import java.util.Map; | ||
| import java.util.function.Function; | ||
| import org.opensearch.common.unit.TimeValue; | ||
| import org.opensearch.index.query.QueryBuilder; | ||
| import org.opensearch.index.query.WrapperQueryBuilder; | ||
| import org.opensearch.sql.common.setting.Settings; | ||
| import org.opensearch.sql.opensearch.client.OpenSearchClient; | ||
| import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; | ||
| import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScan; | ||
| import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; | ||
| import org.opensearch.sql.opensearch.storage.scan.VectorSearchQueryBuilder; | ||
| import org.opensearch.sql.storage.read.TableScanBuilder; | ||
|
|
||
| /** | ||
| * Vector-search-aware OpenSearch index. Seeds the scan with a knn query and enables score tracking. | ||
| */ | ||
| public class VectorSearchIndex extends OpenSearchIndex { | ||
|
|
||
| private final String field; | ||
| private final float[] vector; | ||
| private final Map<String, String> options; | ||
|
|
||
| public VectorSearchIndex( | ||
| OpenSearchClient client, | ||
| Settings settings, | ||
| String indexName, | ||
| String field, | ||
| float[] vector, | ||
| Map<String, String> options) { | ||
| super(client, settings, indexName); | ||
| this.field = field; | ||
| this.vector = vector; | ||
| this.options = options; | ||
| } | ||
|
|
||
| @Override | ||
| public TableScanBuilder createScanBuilder() { | ||
| final TimeValue cursorKeepAlive = | ||
| getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); | ||
| var requestBuilder = createRequestBuilder(); | ||
|
|
||
| // Use VectorSearchQueryBuilder to keep knn in must (scoring) context. | ||
| // WHERE filters will be placed in filter (non-scoring) context. | ||
| var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); | ||
| requestBuilder.pushDownTrackedScore(true); | ||
|
|
||
| // Top-k mode: default size to k so queries without LIMIT return k results | ||
| // instead of falling into the generic large-scan path. | ||
| // LIMIT pushdown will further reduce this if present. | ||
| if (options.containsKey("k")) { | ||
| requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); | ||
| } | ||
|
|
||
| Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator = | ||
| rb -> | ||
| new OpenSearchIndexScan( | ||
| getClient(), | ||
| rb.getMaxResponseSize(), | ||
| rb.build(getIndexName(), cursorKeepAlive, getClient(), getFieldTypes().isEmpty())); | ||
| return new OpenSearchIndexScanBuilder(queryBuilder, createScanOperator); | ||
| } | ||
|
|
||
| private QueryBuilder buildKnnQuery() { | ||
| StringBuilder vectorJson = new StringBuilder("["); | ||
| for (int i = 0; i < vector.length; i++) { | ||
| if (i > 0) vectorJson.append(","); | ||
| vectorJson.append(vector[i]); | ||
| } | ||
| vectorJson.append("]"); | ||
|
|
||
| StringBuilder optionsJson = new StringBuilder(); | ||
| for (Map.Entry<String, String> entry : options.entrySet()) { | ||
| optionsJson.append(","); | ||
| String value = entry.getValue(); | ||
| // All P0 option values are canonicalized to numeric strings by validateOptions(). | ||
| // The quoted fallback is retained for forward compatibility with future non-numeric options. | ||
| if (isNumeric(value)) { | ||
| optionsJson.append(String.format("\"%s\":%s", entry.getKey(), value)); | ||
| } else { | ||
| optionsJson.append(String.format("\"%s\":\"%s\"", entry.getKey(), value)); | ||
| } | ||
| } | ||
|
|
||
| String knnQueryJson = | ||
| String.format( | ||
| "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", | ||
| field, vectorJson.toString(), optionsJson.toString()); | ||
| return new WrapperQueryBuilder(knnQueryJson); | ||
| } | ||
|
|
||
| private static boolean isNumeric(String str) { | ||
| try { | ||
| Double.parseDouble(str); | ||
| return true; | ||
| } catch (NumberFormatException e) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,246 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.sql.opensearch.storage; | ||
|
|
||
| import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.FIELD; | ||
| import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.OPTION; | ||
| import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.TABLE; | ||
| import static org.opensearch.sql.opensearch.storage.VectorSearchTableFunctionResolver.VECTOR; | ||
|
|
||
| import java.util.LinkedHashMap; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Set; | ||
| import java.util.regex.Pattern; | ||
| import java.util.stream.Collectors; | ||
| import org.opensearch.sql.common.setting.Settings; | ||
| import org.opensearch.sql.data.model.ExprValue; | ||
| import org.opensearch.sql.data.type.ExprCoreType; | ||
| import org.opensearch.sql.data.type.ExprType; | ||
| import org.opensearch.sql.exception.ExpressionEvaluationException; | ||
| import org.opensearch.sql.expression.Expression; | ||
| import org.opensearch.sql.expression.FunctionExpression; | ||
| import org.opensearch.sql.expression.NamedArgumentExpression; | ||
| import org.opensearch.sql.expression.env.Environment; | ||
| import org.opensearch.sql.expression.function.FunctionName; | ||
| import org.opensearch.sql.expression.function.TableFunctionImplementation; | ||
| import org.opensearch.sql.opensearch.client.OpenSearchClient; | ||
| import org.opensearch.sql.storage.Table; | ||
|
|
||
| public class VectorSearchTableFunctionImplementation extends FunctionExpression | ||
| implements TableFunctionImplementation { | ||
|
|
||
| /** P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. */ | ||
| static final Set<String> ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score"); | ||
|
|
||
| /** | ||
| * Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores, | ||
| * hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON. | ||
| */ | ||
| private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$"); | ||
|
|
||
| private final FunctionName functionName; | ||
| private final List<Expression> arguments; | ||
| private final OpenSearchClient client; | ||
| private final Settings settings; | ||
|
|
||
/**
 * Creates the implementation object behind a {@code vectorSearch(...)} call.
 *
 * @param functionName name the function was invoked under; reused in error and toString output
 * @param arguments argument expressions exactly as supplied by the caller
 * @param client OpenSearch client forwarded to the index produced by {@link #applyArguments()}
 * @param settings plugin settings forwarded to the index produced by {@link #applyArguments()}
 */
public VectorSearchTableFunctionImplementation(
    FunctionName functionName,
    List<Expression> arguments,
    OpenSearchClient client,
    Settings settings) {
  super(functionName, arguments);
  this.settings = settings;
  this.client = client;
  this.arguments = arguments;
  this.functionName = functionName;
}
|
|
||
| @Override | ||
| public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) { | ||
| throw new UnsupportedOperationException( | ||
| String.format("vectorSearch function [%s] is only supported in FROM clause", functionName)); | ||
| } | ||
|
|
||
| @Override | ||
| public ExprType type() { | ||
| return ExprCoreType.STRUCT; | ||
| } | ||
|
|
||
| @Override | ||
| public String toString() { | ||
| List<String> args = | ||
| arguments.stream() | ||
| .map( | ||
| arg -> { | ||
| if (arg instanceof NamedArgumentExpression) { | ||
| NamedArgumentExpression named = (NamedArgumentExpression) arg; | ||
| return String.format("%s=%s", named.getArgName(), named.getValue().toString()); | ||
| } | ||
| return arg.toString(); | ||
| }) | ||
| .collect(Collectors.toList()); | ||
| return String.format("%s(%s)", functionName, String.join(", ", args)); | ||
| } | ||
|
|
||
| @Override | ||
| public Table applyArguments() { | ||
| validateNamedArgs(); | ||
| String tableName = getArgumentValue(TABLE); | ||
| String fieldName = getArgumentValue(FIELD); | ||
| validateFieldName(fieldName); | ||
| String vectorLiteral = getArgumentValue(VECTOR); | ||
| String optionStr = getArgumentValue(OPTION); | ||
|
|
||
| float[] vector = parseVector(vectorLiteral); | ||
| Map<String, String> options = parseOptions(optionStr); | ||
| validateOptions(options); | ||
|
|
||
| return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options); | ||
| } | ||
|
|
||
| private float[] parseVector(String vectorLiteral) { | ||
| String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim(); | ||
| if (cleaned.isEmpty()) { | ||
| throw new ExpressionEvaluationException("Vector literal must not be empty"); | ||
| } | ||
| String[] parts = cleaned.split(","); | ||
| float[] vector = new float[parts.length]; | ||
| for (int i = 0; i < parts.length; i++) { | ||
| try { | ||
| vector[i] = Float.parseFloat(parts[i].trim()); | ||
| } catch (NumberFormatException e) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Invalid vector component '%s': must be a number", parts[i].trim())); | ||
| } | ||
| if (!Float.isFinite(vector[i])) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format( | ||
| "Invalid vector component '%s': must be a finite number", parts[i].trim())); | ||
| } | ||
| } | ||
| return vector; | ||
| } | ||
|
|
||
/**
 * Parses a comma-separated list of key=value segments into a map that keeps the order in which
 * the keys were written. Keys and values are trimmed and blank segments are skipped.
 *
 * @param optionStr raw option text; split() is called on it directly, so it is expected to be
 *     non-null
 * @return mutable map of option key to raw (still unvalidated) value
 * @throws ExpressionEvaluationException if a segment is not of the form key=value with a
 *     non-blank key and value, or if a key appears more than once
 */
| static Map<String, String> parseOptions(String optionStr) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any idea to simplify the parse/validate logic? Probably double check if some is already guarded by grammar or if it can be done by existing visit method.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I checked the grammar/visitor angle. We can lean on parser/resolver for the outer function shape and arity, but option still reaches this implementation as a single I can simplify the code shape, though: for example by extracting the option parsing/validation into a small helper/value object like VectorSearchOptions.parse(optionStr), so this class reads more like orchestration. |
||
| Map<String, String> options = new LinkedHashMap<>(); | ||
| for (String pair : optionStr.split(",")) { | ||
| String trimmed = pair.trim(); | ||
| if (trimmed.isEmpty()) { | ||
| continue; | ||
| } | ||
// Limit 2: only the first '=' separates key from value; later '=' stay inside the value.
| String[] kv = trimmed.split("=", 2); | ||
| if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Malformed option segment '%s': expected key=value", trimmed)); | ||
| } | ||
| String key = kv[0].trim(); | ||
| if (options.containsKey(key)) { | ||
| throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key)); | ||
| } | ||
| options.put(key, kv[1].trim()); | ||
| } | ||
| return options; | ||
| } | ||
|
|
||
| /** Reject non-named arguments and null arg names early. */ | ||
| private void validateNamedArgs() { | ||
| for (Expression arg : arguments) { | ||
| if (!(arg instanceof NamedArgumentExpression)) { | ||
| throw new ExpressionEvaluationException( | ||
| "vectorSearch() requires named arguments (e.g., table='index'), " | ||
| + "but received: " | ||
| + arg.getClass().getSimpleName()); | ||
| } | ||
| String name = ((NamedArgumentExpression) arg).getArgName(); | ||
| if (name == null || name.isEmpty()) { | ||
| throw new ExpressionEvaluationException( | ||
| "vectorSearch() requires named arguments (e.g., table='index'), " | ||
| + "but received an argument with no name"); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Reject field names with characters that could corrupt the WrapperQueryBuilder JSON. Allows | ||
| * alphanumeric, dots (nested fields), underscores, and hyphens. | ||
| */ | ||
| private void validateFieldName(String fieldName) { | ||
| if (!SAFE_FIELD_NAME.matcher(fieldName).matches()) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format( | ||
| "Invalid field name '%s': must contain only alphanumeric characters," | ||
| + " dots, underscores, or hyphens", | ||
| fieldName)); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Validates and canonicalizes option values. All P0 option values must be numeric. Parsing them | ||
| * here prevents non-numeric strings from reaching the raw JSON construction in buildKnnQuery(). | ||
| */ | ||
| private void validateOptions(Map<String, String> options) { | ||
| // Reject unknown option keys — only P0 keys are allowed | ||
| for (String key : options.keySet()) { | ||
| if (!ALLOWED_OPTION_KEYS.contains(key)) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS)); | ||
| } | ||
| } | ||
| boolean hasK = options.containsKey("k"); | ||
| boolean hasMaxDistance = options.containsKey("max_distance"); | ||
| boolean hasMinScore = options.containsKey("min_score"); | ||
| if (!hasK && !hasMaxDistance && !hasMinScore) { | ||
|
penghuo marked this conversation as resolved.
|
||
| throw new ExpressionEvaluationException( | ||
| "Missing required option: one of k, max_distance, or min_score"); | ||
| } | ||
| // Parse and canonicalize numeric values — closes JSON injection via option values | ||
| if (hasK) { | ||
| parseIntOption(options, "k"); | ||
| } | ||
| if (hasMaxDistance) { | ||
| parseDoubleOption(options, "max_distance"); | ||
| } | ||
| if (hasMinScore) { | ||
| parseDoubleOption(options, "min_score"); | ||
| } | ||
| } | ||
|
|
||
| private void parseIntOption(Map<String, String> options, String key) { | ||
| try { | ||
| int value = Integer.parseInt(options.get(key)); | ||
| options.put(key, Integer.toString(value)); | ||
| } catch (NumberFormatException e) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Option '%s' must be an integer, got '%s'", key, options.get(key))); | ||
| } | ||
| } | ||
|
|
||
| private void parseDoubleOption(Map<String, String> options, String key) { | ||
| try { | ||
| double value = Double.parseDouble(options.get(key)); | ||
| if (!Double.isFinite(value)) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Option '%s' must be a finite number, got '%s'", key, options.get(key))); | ||
| } | ||
| options.put(key, Double.toString(value)); | ||
| } catch (NumberFormatException e) { | ||
| throw new ExpressionEvaluationException( | ||
| String.format("Option '%s' must be a number, got '%s'", key, options.get(key))); | ||
| } | ||
| } | ||
|
|
||
| private String getArgumentValue(String name) { | ||
| return arguments.stream() | ||
| .filter(arg -> ((NamedArgumentExpression) arg).getArgName().equalsIgnoreCase(name)) | ||
| .map(arg -> ((NamedArgumentExpression) arg).getValue().valueOf().stringValue()) | ||
| .findFirst() | ||
| .orElseThrow( | ||
| () -> | ||
| new ExpressionEvaluationException( | ||
| String.format("Missing required argument: %s", name))); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of separate `VectorSearchQueryBuilder` and `VectorSearchIndex` classes, I wonder whether we can abstract the vector search translation logic behind the `LuceneQuery` interface?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I split it this way because the two classes currently own different seams:
`VectorSearchIndex` seeds the scan/request state (knn query + track_scores + size=k default), while `VectorSearchQueryBuilder` handles score-preserving filter pushdown. I agree a shared abstraction could make sense if we see a second consumer or want a cleaner bridge for the unified SQL path, but I'd prefer not to turn PR-2 into a broader refactor unless there's an existing `LuceneQuery` contract that fits both responsibilities cleanly.